##// END OF EJS Templates
code: code cleanups, use is None instead of == None.
marcink -
r3231:462664f1 default
parent child Browse files
Show More
@@ -1,781 +1,780 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2018 RhodeCode GmbH
3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 import logging
22 import logging
23 import collections
23 import collections
24
24
25 import datetime
25 import datetime
26 import formencode
26 import formencode
27 import formencode.htmlfill
27 import formencode.htmlfill
28
28
29 import rhodecode
29 import rhodecode
30 from pyramid.view import view_config
30 from pyramid.view import view_config
31 from pyramid.httpexceptions import HTTPFound, HTTPNotFound
31 from pyramid.httpexceptions import HTTPFound, HTTPNotFound
32 from pyramid.renderers import render
32 from pyramid.renderers import render
33 from pyramid.response import Response
33 from pyramid.response import Response
34
34
35 from rhodecode.apps._base import BaseAppView
35 from rhodecode.apps._base import BaseAppView
36 from rhodecode.apps.admin.navigation import navigation_list
36 from rhodecode.apps.admin.navigation import navigation_list
37 from rhodecode.apps.svn_support.config_keys import generate_config
37 from rhodecode.apps.svn_support.config_keys import generate_config
38 from rhodecode.lib import helpers as h
38 from rhodecode.lib import helpers as h
39 from rhodecode.lib.auth import (
39 from rhodecode.lib.auth import (
40 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
40 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
41 from rhodecode.lib.celerylib import tasks, run_task
41 from rhodecode.lib.celerylib import tasks, run_task
42 from rhodecode.lib.utils import repo2db_mapper
42 from rhodecode.lib.utils import repo2db_mapper
43 from rhodecode.lib.utils2 import str2bool, safe_unicode, AttributeDict
43 from rhodecode.lib.utils2 import str2bool, safe_unicode, AttributeDict
44 from rhodecode.lib.index import searcher_from_config
44 from rhodecode.lib.index import searcher_from_config
45
45
46 from rhodecode.model.db import RhodeCodeUi, Repository
46 from rhodecode.model.db import RhodeCodeUi, Repository
47 from rhodecode.model.forms import (ApplicationSettingsForm,
47 from rhodecode.model.forms import (ApplicationSettingsForm,
48 ApplicationUiSettingsForm, ApplicationVisualisationForm,
48 ApplicationUiSettingsForm, ApplicationVisualisationForm,
49 LabsSettingsForm, IssueTrackerPatternsForm)
49 LabsSettingsForm, IssueTrackerPatternsForm)
50 from rhodecode.model.repo_group import RepoGroupModel
50 from rhodecode.model.repo_group import RepoGroupModel
51
51
52 from rhodecode.model.scm import ScmModel
52 from rhodecode.model.scm import ScmModel
53 from rhodecode.model.notification import EmailNotificationModel
53 from rhodecode.model.notification import EmailNotificationModel
54 from rhodecode.model.meta import Session
54 from rhodecode.model.meta import Session
55 from rhodecode.model.settings import (
55 from rhodecode.model.settings import (
56 IssueTrackerSettingsModel, VcsSettingsModel, SettingNotFound,
56 IssueTrackerSettingsModel, VcsSettingsModel, SettingNotFound,
57 SettingsModel)
57 SettingsModel)
58
58
59
59
60 log = logging.getLogger(__name__)
60 log = logging.getLogger(__name__)
61
61
62
62
63 class AdminSettingsView(BaseAppView):
63 class AdminSettingsView(BaseAppView):
64
64
65 def load_default_context(self):
65 def load_default_context(self):
66 c = self._get_local_tmpl_context()
66 c = self._get_local_tmpl_context()
67 c.labs_active = str2bool(
67 c.labs_active = str2bool(
68 rhodecode.CONFIG.get('labs_settings_active', 'true'))
68 rhodecode.CONFIG.get('labs_settings_active', 'true'))
69 c.navlist = navigation_list(self.request)
69 c.navlist = navigation_list(self.request)
70
70
71 return c
71 return c
72
72
73 @classmethod
73 @classmethod
74 def _get_ui_settings(cls):
74 def _get_ui_settings(cls):
75 ret = RhodeCodeUi.query().all()
75 ret = RhodeCodeUi.query().all()
76
76
77 if not ret:
77 if not ret:
78 raise Exception('Could not get application ui settings !')
78 raise Exception('Could not get application ui settings !')
79 settings = {}
79 settings = {}
80 for each in ret:
80 for each in ret:
81 k = each.ui_key
81 k = each.ui_key
82 v = each.ui_value
82 v = each.ui_value
83 if k == '/':
83 if k == '/':
84 k = 'root_path'
84 k = 'root_path'
85
85
86 if k in ['push_ssl', 'publish', 'enabled']:
86 if k in ['push_ssl', 'publish', 'enabled']:
87 v = str2bool(v)
87 v = str2bool(v)
88
88
89 if k.find('.') != -1:
89 if k.find('.') != -1:
90 k = k.replace('.', '_')
90 k = k.replace('.', '_')
91
91
92 if each.ui_section in ['hooks', 'extensions']:
92 if each.ui_section in ['hooks', 'extensions']:
93 v = each.ui_active
93 v = each.ui_active
94
94
95 settings[each.ui_section + '_' + k] = v
95 settings[each.ui_section + '_' + k] = v
96 return settings
96 return settings
97
97
98 @classmethod
98 @classmethod
99 def _form_defaults(cls):
99 def _form_defaults(cls):
100 defaults = SettingsModel().get_all_settings()
100 defaults = SettingsModel().get_all_settings()
101 defaults.update(cls._get_ui_settings())
101 defaults.update(cls._get_ui_settings())
102
102
103 defaults.update({
103 defaults.update({
104 'new_svn_branch': '',
104 'new_svn_branch': '',
105 'new_svn_tag': '',
105 'new_svn_tag': '',
106 })
106 })
107 return defaults
107 return defaults
108
108
109 @LoginRequired()
109 @LoginRequired()
110 @HasPermissionAllDecorator('hg.admin')
110 @HasPermissionAllDecorator('hg.admin')
111 @view_config(
111 @view_config(
112 route_name='admin_settings_vcs', request_method='GET',
112 route_name='admin_settings_vcs', request_method='GET',
113 renderer='rhodecode:templates/admin/settings/settings.mako')
113 renderer='rhodecode:templates/admin/settings/settings.mako')
114 def settings_vcs(self):
114 def settings_vcs(self):
115 c = self.load_default_context()
115 c = self.load_default_context()
116 c.active = 'vcs'
116 c.active = 'vcs'
117 model = VcsSettingsModel()
117 model = VcsSettingsModel()
118 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
118 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
119 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
119 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
120
120
121 settings = self.request.registry.settings
121 settings = self.request.registry.settings
122 c.svn_proxy_generate_config = settings[generate_config]
122 c.svn_proxy_generate_config = settings[generate_config]
123
123
124 defaults = self._form_defaults()
124 defaults = self._form_defaults()
125
125
126 model.create_largeobjects_dirs_if_needed(defaults['paths_root_path'])
126 model.create_largeobjects_dirs_if_needed(defaults['paths_root_path'])
127
127
128 data = render('rhodecode:templates/admin/settings/settings.mako',
128 data = render('rhodecode:templates/admin/settings/settings.mako',
129 self._get_template_context(c), self.request)
129 self._get_template_context(c), self.request)
130 html = formencode.htmlfill.render(
130 html = formencode.htmlfill.render(
131 data,
131 data,
132 defaults=defaults,
132 defaults=defaults,
133 encoding="UTF-8",
133 encoding="UTF-8",
134 force_defaults=False
134 force_defaults=False
135 )
135 )
136 return Response(html)
136 return Response(html)
137
137
138 @LoginRequired()
138 @LoginRequired()
139 @HasPermissionAllDecorator('hg.admin')
139 @HasPermissionAllDecorator('hg.admin')
140 @CSRFRequired()
140 @CSRFRequired()
141 @view_config(
141 @view_config(
142 route_name='admin_settings_vcs_update', request_method='POST',
142 route_name='admin_settings_vcs_update', request_method='POST',
143 renderer='rhodecode:templates/admin/settings/settings.mako')
143 renderer='rhodecode:templates/admin/settings/settings.mako')
144 def settings_vcs_update(self):
144 def settings_vcs_update(self):
145 _ = self.request.translate
145 _ = self.request.translate
146 c = self.load_default_context()
146 c = self.load_default_context()
147 c.active = 'vcs'
147 c.active = 'vcs'
148
148
149 model = VcsSettingsModel()
149 model = VcsSettingsModel()
150 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
150 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
151 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
151 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
152
152
153 settings = self.request.registry.settings
153 settings = self.request.registry.settings
154 c.svn_proxy_generate_config = settings[generate_config]
154 c.svn_proxy_generate_config = settings[generate_config]
155
155
156 application_form = ApplicationUiSettingsForm(self.request.translate)()
156 application_form = ApplicationUiSettingsForm(self.request.translate)()
157
157
158 try:
158 try:
159 form_result = application_form.to_python(dict(self.request.POST))
159 form_result = application_form.to_python(dict(self.request.POST))
160 except formencode.Invalid as errors:
160 except formencode.Invalid as errors:
161 h.flash(
161 h.flash(
162 _("Some form inputs contain invalid data."),
162 _("Some form inputs contain invalid data."),
163 category='error')
163 category='error')
164 data = render('rhodecode:templates/admin/settings/settings.mako',
164 data = render('rhodecode:templates/admin/settings/settings.mako',
165 self._get_template_context(c), self.request)
165 self._get_template_context(c), self.request)
166 html = formencode.htmlfill.render(
166 html = formencode.htmlfill.render(
167 data,
167 data,
168 defaults=errors.value,
168 defaults=errors.value,
169 errors=errors.error_dict or {},
169 errors=errors.error_dict or {},
170 prefix_error=False,
170 prefix_error=False,
171 encoding="UTF-8",
171 encoding="UTF-8",
172 force_defaults=False
172 force_defaults=False
173 )
173 )
174 return Response(html)
174 return Response(html)
175
175
176 try:
176 try:
177 if c.visual.allow_repo_location_change:
177 if c.visual.allow_repo_location_change:
178 model.update_global_path_setting(
178 model.update_global_path_setting(form_result['paths_root_path'])
179 form_result['paths_root_path'])
180
179
181 model.update_global_ssl_setting(form_result['web_push_ssl'])
180 model.update_global_ssl_setting(form_result['web_push_ssl'])
182 model.update_global_hook_settings(form_result)
181 model.update_global_hook_settings(form_result)
183
182
184 model.create_or_update_global_svn_settings(form_result)
183 model.create_or_update_global_svn_settings(form_result)
185 model.create_or_update_global_hg_settings(form_result)
184 model.create_or_update_global_hg_settings(form_result)
186 model.create_or_update_global_git_settings(form_result)
185 model.create_or_update_global_git_settings(form_result)
187 model.create_or_update_global_pr_settings(form_result)
186 model.create_or_update_global_pr_settings(form_result)
188 except Exception:
187 except Exception:
189 log.exception("Exception while updating settings")
188 log.exception("Exception while updating settings")
190 h.flash(_('Error occurred during updating '
189 h.flash(_('Error occurred during updating '
191 'application settings'), category='error')
190 'application settings'), category='error')
192 else:
191 else:
193 Session().commit()
192 Session().commit()
194 h.flash(_('Updated VCS settings'), category='success')
193 h.flash(_('Updated VCS settings'), category='success')
195 raise HTTPFound(h.route_path('admin_settings_vcs'))
194 raise HTTPFound(h.route_path('admin_settings_vcs'))
196
195
197 data = render('rhodecode:templates/admin/settings/settings.mako',
196 data = render('rhodecode:templates/admin/settings/settings.mako',
198 self._get_template_context(c), self.request)
197 self._get_template_context(c), self.request)
199 html = formencode.htmlfill.render(
198 html = formencode.htmlfill.render(
200 data,
199 data,
201 defaults=self._form_defaults(),
200 defaults=self._form_defaults(),
202 encoding="UTF-8",
201 encoding="UTF-8",
203 force_defaults=False
202 force_defaults=False
204 )
203 )
205 return Response(html)
204 return Response(html)
206
205
207 @LoginRequired()
206 @LoginRequired()
208 @HasPermissionAllDecorator('hg.admin')
207 @HasPermissionAllDecorator('hg.admin')
209 @CSRFRequired()
208 @CSRFRequired()
210 @view_config(
209 @view_config(
211 route_name='admin_settings_vcs_svn_pattern_delete', request_method='POST',
210 route_name='admin_settings_vcs_svn_pattern_delete', request_method='POST',
212 renderer='json_ext', xhr=True)
211 renderer='json_ext', xhr=True)
213 def settings_vcs_delete_svn_pattern(self):
212 def settings_vcs_delete_svn_pattern(self):
214 delete_pattern_id = self.request.POST.get('delete_svn_pattern')
213 delete_pattern_id = self.request.POST.get('delete_svn_pattern')
215 model = VcsSettingsModel()
214 model = VcsSettingsModel()
216 try:
215 try:
217 model.delete_global_svn_pattern(delete_pattern_id)
216 model.delete_global_svn_pattern(delete_pattern_id)
218 except SettingNotFound:
217 except SettingNotFound:
219 log.exception(
218 log.exception(
220 'Failed to delete svn_pattern with id %s', delete_pattern_id)
219 'Failed to delete svn_pattern with id %s', delete_pattern_id)
221 raise HTTPNotFound()
220 raise HTTPNotFound()
222
221
223 Session().commit()
222 Session().commit()
224 return True
223 return True
225
224
226 @LoginRequired()
225 @LoginRequired()
227 @HasPermissionAllDecorator('hg.admin')
226 @HasPermissionAllDecorator('hg.admin')
228 @view_config(
227 @view_config(
229 route_name='admin_settings_mapping', request_method='GET',
228 route_name='admin_settings_mapping', request_method='GET',
230 renderer='rhodecode:templates/admin/settings/settings.mako')
229 renderer='rhodecode:templates/admin/settings/settings.mako')
231 def settings_mapping(self):
230 def settings_mapping(self):
232 c = self.load_default_context()
231 c = self.load_default_context()
233 c.active = 'mapping'
232 c.active = 'mapping'
234
233
235 data = render('rhodecode:templates/admin/settings/settings.mako',
234 data = render('rhodecode:templates/admin/settings/settings.mako',
236 self._get_template_context(c), self.request)
235 self._get_template_context(c), self.request)
237 html = formencode.htmlfill.render(
236 html = formencode.htmlfill.render(
238 data,
237 data,
239 defaults=self._form_defaults(),
238 defaults=self._form_defaults(),
240 encoding="UTF-8",
239 encoding="UTF-8",
241 force_defaults=False
240 force_defaults=False
242 )
241 )
243 return Response(html)
242 return Response(html)
244
243
245 @LoginRequired()
244 @LoginRequired()
246 @HasPermissionAllDecorator('hg.admin')
245 @HasPermissionAllDecorator('hg.admin')
247 @CSRFRequired()
246 @CSRFRequired()
248 @view_config(
247 @view_config(
249 route_name='admin_settings_mapping_update', request_method='POST',
248 route_name='admin_settings_mapping_update', request_method='POST',
250 renderer='rhodecode:templates/admin/settings/settings.mako')
249 renderer='rhodecode:templates/admin/settings/settings.mako')
251 def settings_mapping_update(self):
250 def settings_mapping_update(self):
252 _ = self.request.translate
251 _ = self.request.translate
253 c = self.load_default_context()
252 c = self.load_default_context()
254 c.active = 'mapping'
253 c.active = 'mapping'
255 rm_obsolete = self.request.POST.get('destroy', False)
254 rm_obsolete = self.request.POST.get('destroy', False)
256 invalidate_cache = self.request.POST.get('invalidate', False)
255 invalidate_cache = self.request.POST.get('invalidate', False)
257 log.debug(
256 log.debug(
258 'rescanning repo location with destroy obsolete=%s', rm_obsolete)
257 'rescanning repo location with destroy obsolete=%s', rm_obsolete)
259
258
260 if invalidate_cache:
259 if invalidate_cache:
261 log.debug('invalidating all repositories cache')
260 log.debug('invalidating all repositories cache')
262 for repo in Repository.get_all():
261 for repo in Repository.get_all():
263 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
262 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
264
263
265 filesystem_repos = ScmModel().repo_scan()
264 filesystem_repos = ScmModel().repo_scan()
266 added, removed = repo2db_mapper(filesystem_repos, rm_obsolete)
265 added, removed = repo2db_mapper(filesystem_repos, rm_obsolete)
267 _repr = lambda l: ', '.join(map(safe_unicode, l)) or '-'
266 _repr = lambda l: ', '.join(map(safe_unicode, l)) or '-'
268 h.flash(_('Repositories successfully '
267 h.flash(_('Repositories successfully '
269 'rescanned added: %s ; removed: %s') %
268 'rescanned added: %s ; removed: %s') %
270 (_repr(added), _repr(removed)),
269 (_repr(added), _repr(removed)),
271 category='success')
270 category='success')
272 raise HTTPFound(h.route_path('admin_settings_mapping'))
271 raise HTTPFound(h.route_path('admin_settings_mapping'))
273
272
274 @LoginRequired()
273 @LoginRequired()
275 @HasPermissionAllDecorator('hg.admin')
274 @HasPermissionAllDecorator('hg.admin')
276 @view_config(
275 @view_config(
277 route_name='admin_settings', request_method='GET',
276 route_name='admin_settings', request_method='GET',
278 renderer='rhodecode:templates/admin/settings/settings.mako')
277 renderer='rhodecode:templates/admin/settings/settings.mako')
279 @view_config(
278 @view_config(
280 route_name='admin_settings_global', request_method='GET',
279 route_name='admin_settings_global', request_method='GET',
281 renderer='rhodecode:templates/admin/settings/settings.mako')
280 renderer='rhodecode:templates/admin/settings/settings.mako')
282 def settings_global(self):
281 def settings_global(self):
283 c = self.load_default_context()
282 c = self.load_default_context()
284 c.active = 'global'
283 c.active = 'global'
285 c.personal_repo_group_default_pattern = RepoGroupModel()\
284 c.personal_repo_group_default_pattern = RepoGroupModel()\
286 .get_personal_group_name_pattern()
285 .get_personal_group_name_pattern()
287
286
288 data = render('rhodecode:templates/admin/settings/settings.mako',
287 data = render('rhodecode:templates/admin/settings/settings.mako',
289 self._get_template_context(c), self.request)
288 self._get_template_context(c), self.request)
290 html = formencode.htmlfill.render(
289 html = formencode.htmlfill.render(
291 data,
290 data,
292 defaults=self._form_defaults(),
291 defaults=self._form_defaults(),
293 encoding="UTF-8",
292 encoding="UTF-8",
294 force_defaults=False
293 force_defaults=False
295 )
294 )
296 return Response(html)
295 return Response(html)
297
296
298 @LoginRequired()
297 @LoginRequired()
299 @HasPermissionAllDecorator('hg.admin')
298 @HasPermissionAllDecorator('hg.admin')
300 @CSRFRequired()
299 @CSRFRequired()
301 @view_config(
300 @view_config(
302 route_name='admin_settings_update', request_method='POST',
301 route_name='admin_settings_update', request_method='POST',
303 renderer='rhodecode:templates/admin/settings/settings.mako')
302 renderer='rhodecode:templates/admin/settings/settings.mako')
304 @view_config(
303 @view_config(
305 route_name='admin_settings_global_update', request_method='POST',
304 route_name='admin_settings_global_update', request_method='POST',
306 renderer='rhodecode:templates/admin/settings/settings.mako')
305 renderer='rhodecode:templates/admin/settings/settings.mako')
307 def settings_global_update(self):
306 def settings_global_update(self):
308 _ = self.request.translate
307 _ = self.request.translate
309 c = self.load_default_context()
308 c = self.load_default_context()
310 c.active = 'global'
309 c.active = 'global'
311 c.personal_repo_group_default_pattern = RepoGroupModel()\
310 c.personal_repo_group_default_pattern = RepoGroupModel()\
312 .get_personal_group_name_pattern()
311 .get_personal_group_name_pattern()
313 application_form = ApplicationSettingsForm(self.request.translate)()
312 application_form = ApplicationSettingsForm(self.request.translate)()
314 try:
313 try:
315 form_result = application_form.to_python(dict(self.request.POST))
314 form_result = application_form.to_python(dict(self.request.POST))
316 except formencode.Invalid as errors:
315 except formencode.Invalid as errors:
317 h.flash(
316 h.flash(
318 _("Some form inputs contain invalid data."),
317 _("Some form inputs contain invalid data."),
319 category='error')
318 category='error')
320 data = render('rhodecode:templates/admin/settings/settings.mako',
319 data = render('rhodecode:templates/admin/settings/settings.mako',
321 self._get_template_context(c), self.request)
320 self._get_template_context(c), self.request)
322 html = formencode.htmlfill.render(
321 html = formencode.htmlfill.render(
323 data,
322 data,
324 defaults=errors.value,
323 defaults=errors.value,
325 errors=errors.error_dict or {},
324 errors=errors.error_dict or {},
326 prefix_error=False,
325 prefix_error=False,
327 encoding="UTF-8",
326 encoding="UTF-8",
328 force_defaults=False
327 force_defaults=False
329 )
328 )
330 return Response(html)
329 return Response(html)
331
330
332 settings = [
331 settings = [
333 ('title', 'rhodecode_title', 'unicode'),
332 ('title', 'rhodecode_title', 'unicode'),
334 ('realm', 'rhodecode_realm', 'unicode'),
333 ('realm', 'rhodecode_realm', 'unicode'),
335 ('pre_code', 'rhodecode_pre_code', 'unicode'),
334 ('pre_code', 'rhodecode_pre_code', 'unicode'),
336 ('post_code', 'rhodecode_post_code', 'unicode'),
335 ('post_code', 'rhodecode_post_code', 'unicode'),
337 ('captcha_public_key', 'rhodecode_captcha_public_key', 'unicode'),
336 ('captcha_public_key', 'rhodecode_captcha_public_key', 'unicode'),
338 ('captcha_private_key', 'rhodecode_captcha_private_key', 'unicode'),
337 ('captcha_private_key', 'rhodecode_captcha_private_key', 'unicode'),
339 ('create_personal_repo_group', 'rhodecode_create_personal_repo_group', 'bool'),
338 ('create_personal_repo_group', 'rhodecode_create_personal_repo_group', 'bool'),
340 ('personal_repo_group_pattern', 'rhodecode_personal_repo_group_pattern', 'unicode'),
339 ('personal_repo_group_pattern', 'rhodecode_personal_repo_group_pattern', 'unicode'),
341 ]
340 ]
342 try:
341 try:
343 for setting, form_key, type_ in settings:
342 for setting, form_key, type_ in settings:
344 sett = SettingsModel().create_or_update_setting(
343 sett = SettingsModel().create_or_update_setting(
345 setting, form_result[form_key], type_)
344 setting, form_result[form_key], type_)
346 Session().add(sett)
345 Session().add(sett)
347
346
348 Session().commit()
347 Session().commit()
349 SettingsModel().invalidate_settings_cache()
348 SettingsModel().invalidate_settings_cache()
350 h.flash(_('Updated application settings'), category='success')
349 h.flash(_('Updated application settings'), category='success')
351 except Exception:
350 except Exception:
352 log.exception("Exception while updating application settings")
351 log.exception("Exception while updating application settings")
353 h.flash(
352 h.flash(
354 _('Error occurred during updating application settings'),
353 _('Error occurred during updating application settings'),
355 category='error')
354 category='error')
356
355
357 raise HTTPFound(h.route_path('admin_settings_global'))
356 raise HTTPFound(h.route_path('admin_settings_global'))
358
357
359 @LoginRequired()
358 @LoginRequired()
360 @HasPermissionAllDecorator('hg.admin')
359 @HasPermissionAllDecorator('hg.admin')
361 @view_config(
360 @view_config(
362 route_name='admin_settings_visual', request_method='GET',
361 route_name='admin_settings_visual', request_method='GET',
363 renderer='rhodecode:templates/admin/settings/settings.mako')
362 renderer='rhodecode:templates/admin/settings/settings.mako')
364 def settings_visual(self):
363 def settings_visual(self):
365 c = self.load_default_context()
364 c = self.load_default_context()
366 c.active = 'visual'
365 c.active = 'visual'
367
366
368 data = render('rhodecode:templates/admin/settings/settings.mako',
367 data = render('rhodecode:templates/admin/settings/settings.mako',
369 self._get_template_context(c), self.request)
368 self._get_template_context(c), self.request)
370 html = formencode.htmlfill.render(
369 html = formencode.htmlfill.render(
371 data,
370 data,
372 defaults=self._form_defaults(),
371 defaults=self._form_defaults(),
373 encoding="UTF-8",
372 encoding="UTF-8",
374 force_defaults=False
373 force_defaults=False
375 )
374 )
376 return Response(html)
375 return Response(html)
377
376
378 @LoginRequired()
377 @LoginRequired()
379 @HasPermissionAllDecorator('hg.admin')
378 @HasPermissionAllDecorator('hg.admin')
380 @CSRFRequired()
379 @CSRFRequired()
381 @view_config(
380 @view_config(
382 route_name='admin_settings_visual_update', request_method='POST',
381 route_name='admin_settings_visual_update', request_method='POST',
383 renderer='rhodecode:templates/admin/settings/settings.mako')
382 renderer='rhodecode:templates/admin/settings/settings.mako')
384 def settings_visual_update(self):
383 def settings_visual_update(self):
385 _ = self.request.translate
384 _ = self.request.translate
386 c = self.load_default_context()
385 c = self.load_default_context()
387 c.active = 'visual'
386 c.active = 'visual'
388 application_form = ApplicationVisualisationForm(self.request.translate)()
387 application_form = ApplicationVisualisationForm(self.request.translate)()
389 try:
388 try:
390 form_result = application_form.to_python(dict(self.request.POST))
389 form_result = application_form.to_python(dict(self.request.POST))
391 except formencode.Invalid as errors:
390 except formencode.Invalid as errors:
392 h.flash(
391 h.flash(
393 _("Some form inputs contain invalid data."),
392 _("Some form inputs contain invalid data."),
394 category='error')
393 category='error')
395 data = render('rhodecode:templates/admin/settings/settings.mako',
394 data = render('rhodecode:templates/admin/settings/settings.mako',
396 self._get_template_context(c), self.request)
395 self._get_template_context(c), self.request)
397 html = formencode.htmlfill.render(
396 html = formencode.htmlfill.render(
398 data,
397 data,
399 defaults=errors.value,
398 defaults=errors.value,
400 errors=errors.error_dict or {},
399 errors=errors.error_dict or {},
401 prefix_error=False,
400 prefix_error=False,
402 encoding="UTF-8",
401 encoding="UTF-8",
403 force_defaults=False
402 force_defaults=False
404 )
403 )
405 return Response(html)
404 return Response(html)
406
405
407 try:
406 try:
408 settings = [
407 settings = [
409 ('show_public_icon', 'rhodecode_show_public_icon', 'bool'),
408 ('show_public_icon', 'rhodecode_show_public_icon', 'bool'),
410 ('show_private_icon', 'rhodecode_show_private_icon', 'bool'),
409 ('show_private_icon', 'rhodecode_show_private_icon', 'bool'),
411 ('stylify_metatags', 'rhodecode_stylify_metatags', 'bool'),
410 ('stylify_metatags', 'rhodecode_stylify_metatags', 'bool'),
412 ('repository_fields', 'rhodecode_repository_fields', 'bool'),
411 ('repository_fields', 'rhodecode_repository_fields', 'bool'),
413 ('dashboard_items', 'rhodecode_dashboard_items', 'int'),
412 ('dashboard_items', 'rhodecode_dashboard_items', 'int'),
414 ('admin_grid_items', 'rhodecode_admin_grid_items', 'int'),
413 ('admin_grid_items', 'rhodecode_admin_grid_items', 'int'),
415 ('show_version', 'rhodecode_show_version', 'bool'),
414 ('show_version', 'rhodecode_show_version', 'bool'),
416 ('use_gravatar', 'rhodecode_use_gravatar', 'bool'),
415 ('use_gravatar', 'rhodecode_use_gravatar', 'bool'),
417 ('markup_renderer', 'rhodecode_markup_renderer', 'unicode'),
416 ('markup_renderer', 'rhodecode_markup_renderer', 'unicode'),
418 ('gravatar_url', 'rhodecode_gravatar_url', 'unicode'),
417 ('gravatar_url', 'rhodecode_gravatar_url', 'unicode'),
419 ('clone_uri_tmpl', 'rhodecode_clone_uri_tmpl', 'unicode'),
418 ('clone_uri_tmpl', 'rhodecode_clone_uri_tmpl', 'unicode'),
420 ('clone_uri_ssh_tmpl', 'rhodecode_clone_uri_ssh_tmpl', 'unicode'),
419 ('clone_uri_ssh_tmpl', 'rhodecode_clone_uri_ssh_tmpl', 'unicode'),
421 ('support_url', 'rhodecode_support_url', 'unicode'),
420 ('support_url', 'rhodecode_support_url', 'unicode'),
422 ('show_revision_number', 'rhodecode_show_revision_number', 'bool'),
421 ('show_revision_number', 'rhodecode_show_revision_number', 'bool'),
423 ('show_sha_length', 'rhodecode_show_sha_length', 'int'),
422 ('show_sha_length', 'rhodecode_show_sha_length', 'int'),
424 ]
423 ]
425 for setting, form_key, type_ in settings:
424 for setting, form_key, type_ in settings:
426 sett = SettingsModel().create_or_update_setting(
425 sett = SettingsModel().create_or_update_setting(
427 setting, form_result[form_key], type_)
426 setting, form_result[form_key], type_)
428 Session().add(sett)
427 Session().add(sett)
429
428
430 Session().commit()
429 Session().commit()
431 SettingsModel().invalidate_settings_cache()
430 SettingsModel().invalidate_settings_cache()
432 h.flash(_('Updated visualisation settings'), category='success')
431 h.flash(_('Updated visualisation settings'), category='success')
433 except Exception:
432 except Exception:
434 log.exception("Exception updating visualization settings")
433 log.exception("Exception updating visualization settings")
435 h.flash(_('Error occurred during updating '
434 h.flash(_('Error occurred during updating '
436 'visualisation settings'),
435 'visualisation settings'),
437 category='error')
436 category='error')
438
437
439 raise HTTPFound(h.route_path('admin_settings_visual'))
438 raise HTTPFound(h.route_path('admin_settings_visual'))
440
439
441 @LoginRequired()
440 @LoginRequired()
442 @HasPermissionAllDecorator('hg.admin')
441 @HasPermissionAllDecorator('hg.admin')
443 @view_config(
442 @view_config(
444 route_name='admin_settings_issuetracker', request_method='GET',
443 route_name='admin_settings_issuetracker', request_method='GET',
445 renderer='rhodecode:templates/admin/settings/settings.mako')
444 renderer='rhodecode:templates/admin/settings/settings.mako')
446 def settings_issuetracker(self):
445 def settings_issuetracker(self):
447 c = self.load_default_context()
446 c = self.load_default_context()
448 c.active = 'issuetracker'
447 c.active = 'issuetracker'
449 defaults = SettingsModel().get_all_settings()
448 defaults = SettingsModel().get_all_settings()
450
449
451 entry_key = 'rhodecode_issuetracker_pat_'
450 entry_key = 'rhodecode_issuetracker_pat_'
452
451
453 c.issuetracker_entries = {}
452 c.issuetracker_entries = {}
454 for k, v in defaults.items():
453 for k, v in defaults.items():
455 if k.startswith(entry_key):
454 if k.startswith(entry_key):
456 uid = k[len(entry_key):]
455 uid = k[len(entry_key):]
457 c.issuetracker_entries[uid] = None
456 c.issuetracker_entries[uid] = None
458
457
459 for uid in c.issuetracker_entries:
458 for uid in c.issuetracker_entries:
460 c.issuetracker_entries[uid] = AttributeDict({
459 c.issuetracker_entries[uid] = AttributeDict({
461 'pat': defaults.get('rhodecode_issuetracker_pat_' + uid),
460 'pat': defaults.get('rhodecode_issuetracker_pat_' + uid),
462 'url': defaults.get('rhodecode_issuetracker_url_' + uid),
461 'url': defaults.get('rhodecode_issuetracker_url_' + uid),
463 'pref': defaults.get('rhodecode_issuetracker_pref_' + uid),
462 'pref': defaults.get('rhodecode_issuetracker_pref_' + uid),
464 'desc': defaults.get('rhodecode_issuetracker_desc_' + uid),
463 'desc': defaults.get('rhodecode_issuetracker_desc_' + uid),
465 })
464 })
466
465
467 return self._get_template_context(c)
466 return self._get_template_context(c)
468
467
469 @LoginRequired()
468 @LoginRequired()
470 @HasPermissionAllDecorator('hg.admin')
469 @HasPermissionAllDecorator('hg.admin')
471 @CSRFRequired()
470 @CSRFRequired()
472 @view_config(
471 @view_config(
473 route_name='admin_settings_issuetracker_test', request_method='POST',
472 route_name='admin_settings_issuetracker_test', request_method='POST',
474 renderer='string', xhr=True)
473 renderer='string', xhr=True)
475 def settings_issuetracker_test(self):
474 def settings_issuetracker_test(self):
476 return h.urlify_commit_message(
475 return h.urlify_commit_message(
477 self.request.POST.get('test_text', ''),
476 self.request.POST.get('test_text', ''),
478 'repo_group/test_repo1')
477 'repo_group/test_repo1')
479
478
480 @LoginRequired()
479 @LoginRequired()
481 @HasPermissionAllDecorator('hg.admin')
480 @HasPermissionAllDecorator('hg.admin')
482 @CSRFRequired()
481 @CSRFRequired()
483 @view_config(
482 @view_config(
484 route_name='admin_settings_issuetracker_update', request_method='POST',
483 route_name='admin_settings_issuetracker_update', request_method='POST',
485 renderer='rhodecode:templates/admin/settings/settings.mako')
484 renderer='rhodecode:templates/admin/settings/settings.mako')
486 def settings_issuetracker_update(self):
485 def settings_issuetracker_update(self):
487 _ = self.request.translate
486 _ = self.request.translate
488 self.load_default_context()
487 self.load_default_context()
489 settings_model = IssueTrackerSettingsModel()
488 settings_model = IssueTrackerSettingsModel()
490
489
491 try:
490 try:
492 form = IssueTrackerPatternsForm(self.request.translate)()
491 form = IssueTrackerPatternsForm(self.request.translate)()
493 data = form.to_python(self.request.POST)
492 data = form.to_python(self.request.POST)
494 except formencode.Invalid as errors:
493 except formencode.Invalid as errors:
495 log.exception('Failed to add new pattern')
494 log.exception('Failed to add new pattern')
496 error = errors
495 error = errors
497 h.flash(_('Invalid issue tracker pattern: {}'.format(error)),
496 h.flash(_('Invalid issue tracker pattern: {}'.format(error)),
498 category='error')
497 category='error')
499 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
498 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
500
499
501 if data:
500 if data:
502 for uid in data.get('delete_patterns', []):
501 for uid in data.get('delete_patterns', []):
503 settings_model.delete_entries(uid)
502 settings_model.delete_entries(uid)
504
503
505 for pattern in data.get('patterns', []):
504 for pattern in data.get('patterns', []):
506 for setting, value, type_ in pattern:
505 for setting, value, type_ in pattern:
507 sett = settings_model.create_or_update_setting(
506 sett = settings_model.create_or_update_setting(
508 setting, value, type_)
507 setting, value, type_)
509 Session().add(sett)
508 Session().add(sett)
510
509
511 Session().commit()
510 Session().commit()
512
511
513 SettingsModel().invalidate_settings_cache()
512 SettingsModel().invalidate_settings_cache()
514 h.flash(_('Updated issue tracker entries'), category='success')
513 h.flash(_('Updated issue tracker entries'), category='success')
515 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
514 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
516
515
517 @LoginRequired()
516 @LoginRequired()
518 @HasPermissionAllDecorator('hg.admin')
517 @HasPermissionAllDecorator('hg.admin')
519 @CSRFRequired()
518 @CSRFRequired()
520 @view_config(
519 @view_config(
521 route_name='admin_settings_issuetracker_delete', request_method='POST',
520 route_name='admin_settings_issuetracker_delete', request_method='POST',
522 renderer='rhodecode:templates/admin/settings/settings.mako')
521 renderer='rhodecode:templates/admin/settings/settings.mako')
523 def settings_issuetracker_delete(self):
522 def settings_issuetracker_delete(self):
524 _ = self.request.translate
523 _ = self.request.translate
525 self.load_default_context()
524 self.load_default_context()
526 uid = self.request.POST.get('uid')
525 uid = self.request.POST.get('uid')
527 try:
526 try:
528 IssueTrackerSettingsModel().delete_entries(uid)
527 IssueTrackerSettingsModel().delete_entries(uid)
529 except Exception:
528 except Exception:
530 log.exception('Failed to delete issue tracker setting %s', uid)
529 log.exception('Failed to delete issue tracker setting %s', uid)
531 raise HTTPNotFound()
530 raise HTTPNotFound()
532 h.flash(_('Removed issue tracker entry'), category='success')
531 h.flash(_('Removed issue tracker entry'), category='success')
533 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
532 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
534
533
535 @LoginRequired()
534 @LoginRequired()
536 @HasPermissionAllDecorator('hg.admin')
535 @HasPermissionAllDecorator('hg.admin')
537 @view_config(
536 @view_config(
538 route_name='admin_settings_email', request_method='GET',
537 route_name='admin_settings_email', request_method='GET',
539 renderer='rhodecode:templates/admin/settings/settings.mako')
538 renderer='rhodecode:templates/admin/settings/settings.mako')
540 def settings_email(self):
539 def settings_email(self):
541 c = self.load_default_context()
540 c = self.load_default_context()
542 c.active = 'email'
541 c.active = 'email'
543 c.rhodecode_ini = rhodecode.CONFIG
542 c.rhodecode_ini = rhodecode.CONFIG
544
543
545 data = render('rhodecode:templates/admin/settings/settings.mako',
544 data = render('rhodecode:templates/admin/settings/settings.mako',
546 self._get_template_context(c), self.request)
545 self._get_template_context(c), self.request)
547 html = formencode.htmlfill.render(
546 html = formencode.htmlfill.render(
548 data,
547 data,
549 defaults=self._form_defaults(),
548 defaults=self._form_defaults(),
550 encoding="UTF-8",
549 encoding="UTF-8",
551 force_defaults=False
550 force_defaults=False
552 )
551 )
553 return Response(html)
552 return Response(html)
554
553
555 @LoginRequired()
554 @LoginRequired()
556 @HasPermissionAllDecorator('hg.admin')
555 @HasPermissionAllDecorator('hg.admin')
557 @CSRFRequired()
556 @CSRFRequired()
558 @view_config(
557 @view_config(
559 route_name='admin_settings_email_update', request_method='POST',
558 route_name='admin_settings_email_update', request_method='POST',
560 renderer='rhodecode:templates/admin/settings/settings.mako')
559 renderer='rhodecode:templates/admin/settings/settings.mako')
561 def settings_email_update(self):
560 def settings_email_update(self):
562 _ = self.request.translate
561 _ = self.request.translate
563 c = self.load_default_context()
562 c = self.load_default_context()
564 c.active = 'email'
563 c.active = 'email'
565
564
566 test_email = self.request.POST.get('test_email')
565 test_email = self.request.POST.get('test_email')
567
566
568 if not test_email:
567 if not test_email:
569 h.flash(_('Please enter email address'), category='error')
568 h.flash(_('Please enter email address'), category='error')
570 raise HTTPFound(h.route_path('admin_settings_email'))
569 raise HTTPFound(h.route_path('admin_settings_email'))
571
570
572 email_kwargs = {
571 email_kwargs = {
573 'date': datetime.datetime.now(),
572 'date': datetime.datetime.now(),
574 'user': c.rhodecode_user,
573 'user': c.rhodecode_user,
575 'rhodecode_version': c.rhodecode_version
574 'rhodecode_version': c.rhodecode_version
576 }
575 }
577
576
578 (subject, headers, email_body,
577 (subject, headers, email_body,
579 email_body_plaintext) = EmailNotificationModel().render_email(
578 email_body_plaintext) = EmailNotificationModel().render_email(
580 EmailNotificationModel.TYPE_EMAIL_TEST, **email_kwargs)
579 EmailNotificationModel.TYPE_EMAIL_TEST, **email_kwargs)
581
580
582 recipients = [test_email] if test_email else None
581 recipients = [test_email] if test_email else None
583
582
584 run_task(tasks.send_email, recipients, subject,
583 run_task(tasks.send_email, recipients, subject,
585 email_body_plaintext, email_body)
584 email_body_plaintext, email_body)
586
585
587 h.flash(_('Send email task created'), category='success')
586 h.flash(_('Send email task created'), category='success')
588 raise HTTPFound(h.route_path('admin_settings_email'))
587 raise HTTPFound(h.route_path('admin_settings_email'))
589
588
590 @LoginRequired()
589 @LoginRequired()
591 @HasPermissionAllDecorator('hg.admin')
590 @HasPermissionAllDecorator('hg.admin')
592 @view_config(
591 @view_config(
593 route_name='admin_settings_hooks', request_method='GET',
592 route_name='admin_settings_hooks', request_method='GET',
594 renderer='rhodecode:templates/admin/settings/settings.mako')
593 renderer='rhodecode:templates/admin/settings/settings.mako')
595 def settings_hooks(self):
594 def settings_hooks(self):
596 c = self.load_default_context()
595 c = self.load_default_context()
597 c.active = 'hooks'
596 c.active = 'hooks'
598
597
599 model = SettingsModel()
598 model = SettingsModel()
600 c.hooks = model.get_builtin_hooks()
599 c.hooks = model.get_builtin_hooks()
601 c.custom_hooks = model.get_custom_hooks()
600 c.custom_hooks = model.get_custom_hooks()
602
601
603 data = render('rhodecode:templates/admin/settings/settings.mako',
602 data = render('rhodecode:templates/admin/settings/settings.mako',
604 self._get_template_context(c), self.request)
603 self._get_template_context(c), self.request)
605 html = formencode.htmlfill.render(
604 html = formencode.htmlfill.render(
606 data,
605 data,
607 defaults=self._form_defaults(),
606 defaults=self._form_defaults(),
608 encoding="UTF-8",
607 encoding="UTF-8",
609 force_defaults=False
608 force_defaults=False
610 )
609 )
611 return Response(html)
610 return Response(html)
612
611
613 @LoginRequired()
612 @LoginRequired()
614 @HasPermissionAllDecorator('hg.admin')
613 @HasPermissionAllDecorator('hg.admin')
615 @CSRFRequired()
614 @CSRFRequired()
616 @view_config(
615 @view_config(
617 route_name='admin_settings_hooks_update', request_method='POST',
616 route_name='admin_settings_hooks_update', request_method='POST',
618 renderer='rhodecode:templates/admin/settings/settings.mako')
617 renderer='rhodecode:templates/admin/settings/settings.mako')
619 @view_config(
618 @view_config(
620 route_name='admin_settings_hooks_delete', request_method='POST',
619 route_name='admin_settings_hooks_delete', request_method='POST',
621 renderer='rhodecode:templates/admin/settings/settings.mako')
620 renderer='rhodecode:templates/admin/settings/settings.mako')
622 def settings_hooks_update(self):
621 def settings_hooks_update(self):
623 _ = self.request.translate
622 _ = self.request.translate
624 c = self.load_default_context()
623 c = self.load_default_context()
625 c.active = 'hooks'
624 c.active = 'hooks'
626 if c.visual.allow_custom_hooks_settings:
625 if c.visual.allow_custom_hooks_settings:
627 ui_key = self.request.POST.get('new_hook_ui_key')
626 ui_key = self.request.POST.get('new_hook_ui_key')
628 ui_value = self.request.POST.get('new_hook_ui_value')
627 ui_value = self.request.POST.get('new_hook_ui_value')
629
628
630 hook_id = self.request.POST.get('hook_id')
629 hook_id = self.request.POST.get('hook_id')
631 new_hook = False
630 new_hook = False
632
631
633 model = SettingsModel()
632 model = SettingsModel()
634 try:
633 try:
635 if ui_value and ui_key:
634 if ui_value and ui_key:
636 model.create_or_update_hook(ui_key, ui_value)
635 model.create_or_update_hook(ui_key, ui_value)
637 h.flash(_('Added new hook'), category='success')
636 h.flash(_('Added new hook'), category='success')
638 new_hook = True
637 new_hook = True
639 elif hook_id:
638 elif hook_id:
640 RhodeCodeUi.delete(hook_id)
639 RhodeCodeUi.delete(hook_id)
641 Session().commit()
640 Session().commit()
642
641
643 # check for edits
642 # check for edits
644 update = False
643 update = False
645 _d = self.request.POST.dict_of_lists()
644 _d = self.request.POST.dict_of_lists()
646 for k, v in zip(_d.get('hook_ui_key', []),
645 for k, v in zip(_d.get('hook_ui_key', []),
647 _d.get('hook_ui_value_new', [])):
646 _d.get('hook_ui_value_new', [])):
648 model.create_or_update_hook(k, v)
647 model.create_or_update_hook(k, v)
649 update = True
648 update = True
650
649
651 if update and not new_hook:
650 if update and not new_hook:
652 h.flash(_('Updated hooks'), category='success')
651 h.flash(_('Updated hooks'), category='success')
653 Session().commit()
652 Session().commit()
654 except Exception:
653 except Exception:
655 log.exception("Exception during hook creation")
654 log.exception("Exception during hook creation")
656 h.flash(_('Error occurred during hook creation'),
655 h.flash(_('Error occurred during hook creation'),
657 category='error')
656 category='error')
658
657
659 raise HTTPFound(h.route_path('admin_settings_hooks'))
658 raise HTTPFound(h.route_path('admin_settings_hooks'))
660
659
661 @LoginRequired()
660 @LoginRequired()
662 @HasPermissionAllDecorator('hg.admin')
661 @HasPermissionAllDecorator('hg.admin')
663 @view_config(
662 @view_config(
664 route_name='admin_settings_search', request_method='GET',
663 route_name='admin_settings_search', request_method='GET',
665 renderer='rhodecode:templates/admin/settings/settings.mako')
664 renderer='rhodecode:templates/admin/settings/settings.mako')
666 def settings_search(self):
665 def settings_search(self):
667 c = self.load_default_context()
666 c = self.load_default_context()
668 c.active = 'search'
667 c.active = 'search'
669
668
670 searcher = searcher_from_config(self.request.registry.settings)
669 searcher = searcher_from_config(self.request.registry.settings)
671 c.statistics = searcher.statistics(self.request.translate)
670 c.statistics = searcher.statistics(self.request.translate)
672
671
673 return self._get_template_context(c)
672 return self._get_template_context(c)
674
673
675 @LoginRequired()
674 @LoginRequired()
676 @HasPermissionAllDecorator('hg.admin')
675 @HasPermissionAllDecorator('hg.admin')
677 @view_config(
676 @view_config(
678 route_name='admin_settings_automation', request_method='GET',
677 route_name='admin_settings_automation', request_method='GET',
679 renderer='rhodecode:templates/admin/settings/settings.mako')
678 renderer='rhodecode:templates/admin/settings/settings.mako')
680 def settings_automation(self):
679 def settings_automation(self):
681 c = self.load_default_context()
680 c = self.load_default_context()
682 c.active = 'automation'
681 c.active = 'automation'
683
682
684 return self._get_template_context(c)
683 return self._get_template_context(c)
685
684
686 @LoginRequired()
685 @LoginRequired()
687 @HasPermissionAllDecorator('hg.admin')
686 @HasPermissionAllDecorator('hg.admin')
688 @view_config(
687 @view_config(
689 route_name='admin_settings_labs', request_method='GET',
688 route_name='admin_settings_labs', request_method='GET',
690 renderer='rhodecode:templates/admin/settings/settings.mako')
689 renderer='rhodecode:templates/admin/settings/settings.mako')
691 def settings_labs(self):
690 def settings_labs(self):
692 c = self.load_default_context()
691 c = self.load_default_context()
693 if not c.labs_active:
692 if not c.labs_active:
694 raise HTTPFound(h.route_path('admin_settings'))
693 raise HTTPFound(h.route_path('admin_settings'))
695
694
696 c.active = 'labs'
695 c.active = 'labs'
697 c.lab_settings = _LAB_SETTINGS
696 c.lab_settings = _LAB_SETTINGS
698
697
699 data = render('rhodecode:templates/admin/settings/settings.mako',
698 data = render('rhodecode:templates/admin/settings/settings.mako',
700 self._get_template_context(c), self.request)
699 self._get_template_context(c), self.request)
701 html = formencode.htmlfill.render(
700 html = formencode.htmlfill.render(
702 data,
701 data,
703 defaults=self._form_defaults(),
702 defaults=self._form_defaults(),
704 encoding="UTF-8",
703 encoding="UTF-8",
705 force_defaults=False
704 force_defaults=False
706 )
705 )
707 return Response(html)
706 return Response(html)
708
707
709 @LoginRequired()
708 @LoginRequired()
710 @HasPermissionAllDecorator('hg.admin')
709 @HasPermissionAllDecorator('hg.admin')
711 @CSRFRequired()
710 @CSRFRequired()
712 @view_config(
711 @view_config(
713 route_name='admin_settings_labs_update', request_method='POST',
712 route_name='admin_settings_labs_update', request_method='POST',
714 renderer='rhodecode:templates/admin/settings/settings.mako')
713 renderer='rhodecode:templates/admin/settings/settings.mako')
715 def settings_labs_update(self):
714 def settings_labs_update(self):
716 _ = self.request.translate
715 _ = self.request.translate
717 c = self.load_default_context()
716 c = self.load_default_context()
718 c.active = 'labs'
717 c.active = 'labs'
719
718
720 application_form = LabsSettingsForm(self.request.translate)()
719 application_form = LabsSettingsForm(self.request.translate)()
721 try:
720 try:
722 form_result = application_form.to_python(dict(self.request.POST))
721 form_result = application_form.to_python(dict(self.request.POST))
723 except formencode.Invalid as errors:
722 except formencode.Invalid as errors:
724 h.flash(
723 h.flash(
725 _("Some form inputs contain invalid data."),
724 _("Some form inputs contain invalid data."),
726 category='error')
725 category='error')
727 data = render('rhodecode:templates/admin/settings/settings.mako',
726 data = render('rhodecode:templates/admin/settings/settings.mako',
728 self._get_template_context(c), self.request)
727 self._get_template_context(c), self.request)
729 html = formencode.htmlfill.render(
728 html = formencode.htmlfill.render(
730 data,
729 data,
731 defaults=errors.value,
730 defaults=errors.value,
732 errors=errors.error_dict or {},
731 errors=errors.error_dict or {},
733 prefix_error=False,
732 prefix_error=False,
734 encoding="UTF-8",
733 encoding="UTF-8",
735 force_defaults=False
734 force_defaults=False
736 )
735 )
737 return Response(html)
736 return Response(html)
738
737
739 try:
738 try:
740 session = Session()
739 session = Session()
741 for setting in _LAB_SETTINGS:
740 for setting in _LAB_SETTINGS:
742 setting_name = setting.key[len('rhodecode_'):]
741 setting_name = setting.key[len('rhodecode_'):]
743 sett = SettingsModel().create_or_update_setting(
742 sett = SettingsModel().create_or_update_setting(
744 setting_name, form_result[setting.key], setting.type)
743 setting_name, form_result[setting.key], setting.type)
745 session.add(sett)
744 session.add(sett)
746
745
747 except Exception:
746 except Exception:
748 log.exception('Exception while updating lab settings')
747 log.exception('Exception while updating lab settings')
749 h.flash(_('Error occurred during updating labs settings'),
748 h.flash(_('Error occurred during updating labs settings'),
750 category='error')
749 category='error')
751 else:
750 else:
752 Session().commit()
751 Session().commit()
753 SettingsModel().invalidate_settings_cache()
752 SettingsModel().invalidate_settings_cache()
754 h.flash(_('Updated Labs settings'), category='success')
753 h.flash(_('Updated Labs settings'), category='success')
755 raise HTTPFound(h.route_path('admin_settings_labs'))
754 raise HTTPFound(h.route_path('admin_settings_labs'))
756
755
757 data = render('rhodecode:templates/admin/settings/settings.mako',
756 data = render('rhodecode:templates/admin/settings/settings.mako',
758 self._get_template_context(c), self.request)
757 self._get_template_context(c), self.request)
759 html = formencode.htmlfill.render(
758 html = formencode.htmlfill.render(
760 data,
759 data,
761 defaults=self._form_defaults(),
760 defaults=self._form_defaults(),
762 encoding="UTF-8",
761 encoding="UTF-8",
763 force_defaults=False
762 force_defaults=False
764 )
763 )
765 return Response(html)
764 return Response(html)
766
765
767
766
768 # :param key: name of the setting including the 'rhodecode_' prefix
767 # :param key: name of the setting including the 'rhodecode_' prefix
769 # :param type: the RhodeCodeSetting type to use.
768 # :param type: the RhodeCodeSetting type to use.
770 # :param group: the i18ned group in which we should dispaly this setting
769 # :param group: the i18ned group in which we should dispaly this setting
771 # :param label: the i18ned label we should display for this setting
770 # :param label: the i18ned label we should display for this setting
772 # :param help: the i18ned help we should dispaly for this setting
771 # :param help: the i18ned help we should dispaly for this setting
773 LabSetting = collections.namedtuple(
772 LabSetting = collections.namedtuple(
774 'LabSetting', ('key', 'type', 'group', 'label', 'help'))
773 'LabSetting', ('key', 'type', 'group', 'label', 'help'))
775
774
776
775
777 # This list has to be kept in sync with the form
776 # This list has to be kept in sync with the form
778 # rhodecode.model.forms.LabsSettingsForm.
777 # rhodecode.model.forms.LabsSettingsForm.
779 _LAB_SETTINGS = [
778 _LAB_SETTINGS = [
780
779
781 ]
780 ]
@@ -1,1919 +1,1919 b''
1 #!/usr/bin/python2.4
1 #!/usr/bin/python2.4
2
2
3 from __future__ import division
3 from __future__ import division
4
4
5 """Diff Match and Patch
5 """Diff Match and Patch
6
6
7 Copyright 2006 Google Inc.
7 Copyright 2006 Google Inc.
8 http://code.google.com/p/google-diff-match-patch/
8 http://code.google.com/p/google-diff-match-patch/
9
9
10 Licensed under the Apache License, Version 2.0 (the "License");
10 Licensed under the Apache License, Version 2.0 (the "License");
11 you may not use this file except in compliance with the License.
11 you may not use this file except in compliance with the License.
12 You may obtain a copy of the License at
12 You may obtain a copy of the License at
13
13
14 http://www.apache.org/licenses/LICENSE-2.0
14 http://www.apache.org/licenses/LICENSE-2.0
15
15
16 Unless required by applicable law or agreed to in writing, software
16 Unless required by applicable law or agreed to in writing, software
17 distributed under the License is distributed on an "AS IS" BASIS,
17 distributed under the License is distributed on an "AS IS" BASIS,
18 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 See the License for the specific language governing permissions and
19 See the License for the specific language governing permissions and
20 limitations under the License.
20 limitations under the License.
21 """
21 """
22
22
23 """Functions for diff, match and patch.
23 """Functions for diff, match and patch.
24
24
25 Computes the difference between two texts to create a patch.
25 Computes the difference between two texts to create a patch.
26 Applies the patch onto another text, allowing for errors.
26 Applies the patch onto another text, allowing for errors.
27 """
27 """
28
28
29 __author__ = 'fraser@google.com (Neil Fraser)'
29 __author__ = 'fraser@google.com (Neil Fraser)'
30
30
31 import math
31 import math
32 import re
32 import re
33 import sys
33 import sys
34 import time
34 import time
35 import urllib
35 import urllib
36
36
37 class diff_match_patch:
37 class diff_match_patch:
38 """Class containing the diff, match and patch methods.
38 """Class containing the diff, match and patch methods.
39
39
40 Also contains the behaviour settings.
40 Also contains the behaviour settings.
41 """
41 """
42
42
43 def __init__(self):
43 def __init__(self):
44 """Inits a diff_match_patch object with default settings.
44 """Inits a diff_match_patch object with default settings.
45 Redefine these in your program to override the defaults.
45 Redefine these in your program to override the defaults.
46 """
46 """
47
47
48 # Number of seconds to map a diff before giving up (0 for infinity).
48 # Number of seconds to map a diff before giving up (0 for infinity).
49 self.Diff_Timeout = 1.0
49 self.Diff_Timeout = 1.0
50 # Cost of an empty edit operation in terms of edit characters.
50 # Cost of an empty edit operation in terms of edit characters.
51 self.Diff_EditCost = 4
51 self.Diff_EditCost = 4
52 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
52 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
53 self.Match_Threshold = 0.5
53 self.Match_Threshold = 0.5
54 # How far to search for a match (0 = exact location, 1000+ = broad match).
54 # How far to search for a match (0 = exact location, 1000+ = broad match).
55 # A match this many characters away from the expected location will add
55 # A match this many characters away from the expected location will add
56 # 1.0 to the score (0.0 is a perfect match).
56 # 1.0 to the score (0.0 is a perfect match).
57 self.Match_Distance = 1000
57 self.Match_Distance = 1000
58 # When deleting a large block of text (over ~64 characters), how close do
58 # When deleting a large block of text (over ~64 characters), how close do
59 # the contents have to be to match the expected contents. (0.0 = perfection,
59 # the contents have to be to match the expected contents. (0.0 = perfection,
60 # 1.0 = very loose). Note that Match_Threshold controls how closely the
60 # 1.0 = very loose). Note that Match_Threshold controls how closely the
61 # end points of a delete need to match.
61 # end points of a delete need to match.
62 self.Patch_DeleteThreshold = 0.5
62 self.Patch_DeleteThreshold = 0.5
63 # Chunk size for context length.
63 # Chunk size for context length.
64 self.Patch_Margin = 4
64 self.Patch_Margin = 4
65
65
66 # The number of bits in an int.
66 # The number of bits in an int.
67 # Python has no maximum, thus to disable patch splitting set to 0.
67 # Python has no maximum, thus to disable patch splitting set to 0.
68 # However to avoid long patches in certain pathological cases, use 32.
68 # However to avoid long patches in certain pathological cases, use 32.
69 # Multiple short patches (using native ints) are much faster than long ones.
69 # Multiple short patches (using native ints) are much faster than long ones.
70 self.Match_MaxBits = 32
70 self.Match_MaxBits = 32
71
71
72 # DIFF FUNCTIONS
72 # DIFF FUNCTIONS
73
73
74 # The data structure representing a diff is an array of tuples:
74 # The data structure representing a diff is an array of tuples:
75 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
75 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
76 # which means: delete "Hello", add "Goodbye" and keep " world."
76 # which means: delete "Hello", add "Goodbye" and keep " world."
77 DIFF_DELETE = -1
77 DIFF_DELETE = -1
78 DIFF_INSERT = 1
78 DIFF_INSERT = 1
79 DIFF_EQUAL = 0
79 DIFF_EQUAL = 0
80
80
81 def diff_main(self, text1, text2, checklines=True, deadline=None):
81 def diff_main(self, text1, text2, checklines=True, deadline=None):
82 """Find the differences between two texts. Simplifies the problem by
82 """Find the differences between two texts. Simplifies the problem by
83 stripping any common prefix or suffix off the texts before diffing.
83 stripping any common prefix or suffix off the texts before diffing.
84
84
85 Args:
85 Args:
86 text1: Old string to be diffed.
86 text1: Old string to be diffed.
87 text2: New string to be diffed.
87 text2: New string to be diffed.
88 checklines: Optional speedup flag. If present and false, then don't run
88 checklines: Optional speedup flag. If present and false, then don't run
89 a line-level diff first to identify the changed areas.
89 a line-level diff first to identify the changed areas.
90 Defaults to true, which does a faster, slightly less optimal diff.
90 Defaults to true, which does a faster, slightly less optimal diff.
91 deadline: Optional time when the diff should be complete by. Used
91 deadline: Optional time when the diff should be complete by. Used
92 internally for recursive calls. Users should set DiffTimeout instead.
92 internally for recursive calls. Users should set DiffTimeout instead.
93
93
94 Returns:
94 Returns:
95 Array of changes.
95 Array of changes.
96 """
96 """
97 # Set a deadline by which time the diff must be complete.
97 # Set a deadline by which time the diff must be complete.
98 if deadline == None:
98 if deadline is None:
99 # Unlike in most languages, Python counts time in seconds.
99 # Unlike in most languages, Python counts time in seconds.
100 if self.Diff_Timeout <= 0:
100 if self.Diff_Timeout <= 0:
101 deadline = sys.maxint
101 deadline = sys.maxint
102 else:
102 else:
103 deadline = time.time() + self.Diff_Timeout
103 deadline = time.time() + self.Diff_Timeout
104
104
105 # Check for null inputs.
105 # Check for null inputs.
106 if text1 == None or text2 == None:
106 if text1 is None or text2 is None:
107 raise ValueError("Null inputs. (diff_main)")
107 raise ValueError("Null inputs. (diff_main)")
108
108
109 # Check for equality (speedup).
109 # Check for equality (speedup).
110 if text1 == text2:
110 if text1 == text2:
111 if text1:
111 if text1:
112 return [(self.DIFF_EQUAL, text1)]
112 return [(self.DIFF_EQUAL, text1)]
113 return []
113 return []
114
114
115 # Trim off common prefix (speedup).
115 # Trim off common prefix (speedup).
116 commonlength = self.diff_commonPrefix(text1, text2)
116 commonlength = self.diff_commonPrefix(text1, text2)
117 commonprefix = text1[:commonlength]
117 commonprefix = text1[:commonlength]
118 text1 = text1[commonlength:]
118 text1 = text1[commonlength:]
119 text2 = text2[commonlength:]
119 text2 = text2[commonlength:]
120
120
121 # Trim off common suffix (speedup).
121 # Trim off common suffix (speedup).
122 commonlength = self.diff_commonSuffix(text1, text2)
122 commonlength = self.diff_commonSuffix(text1, text2)
123 if commonlength == 0:
123 if commonlength == 0:
124 commonsuffix = ''
124 commonsuffix = ''
125 else:
125 else:
126 commonsuffix = text1[-commonlength:]
126 commonsuffix = text1[-commonlength:]
127 text1 = text1[:-commonlength]
127 text1 = text1[:-commonlength]
128 text2 = text2[:-commonlength]
128 text2 = text2[:-commonlength]
129
129
130 # Compute the diff on the middle block.
130 # Compute the diff on the middle block.
131 diffs = self.diff_compute(text1, text2, checklines, deadline)
131 diffs = self.diff_compute(text1, text2, checklines, deadline)
132
132
133 # Restore the prefix and suffix.
133 # Restore the prefix and suffix.
134 if commonprefix:
134 if commonprefix:
135 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
135 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
136 if commonsuffix:
136 if commonsuffix:
137 diffs.append((self.DIFF_EQUAL, commonsuffix))
137 diffs.append((self.DIFF_EQUAL, commonsuffix))
138 self.diff_cleanupMerge(diffs)
138 self.diff_cleanupMerge(diffs)
139 return diffs
139 return diffs
140
140
141 def diff_compute(self, text1, text2, checklines, deadline):
141 def diff_compute(self, text1, text2, checklines, deadline):
142 """Find the differences between two texts. Assumes that the texts do not
142 """Find the differences between two texts. Assumes that the texts do not
143 have any common prefix or suffix.
143 have any common prefix or suffix.
144
144
145 Args:
145 Args:
146 text1: Old string to be diffed.
146 text1: Old string to be diffed.
147 text2: New string to be diffed.
147 text2: New string to be diffed.
148 checklines: Speedup flag. If false, then don't run a line-level diff
148 checklines: Speedup flag. If false, then don't run a line-level diff
149 first to identify the changed areas.
149 first to identify the changed areas.
150 If true, then run a faster, slightly less optimal diff.
150 If true, then run a faster, slightly less optimal diff.
151 deadline: Time when the diff should be complete by.
151 deadline: Time when the diff should be complete by.
152
152
153 Returns:
153 Returns:
154 Array of changes.
154 Array of changes.
155 """
155 """
156 if not text1:
156 if not text1:
157 # Just add some text (speedup).
157 # Just add some text (speedup).
158 return [(self.DIFF_INSERT, text2)]
158 return [(self.DIFF_INSERT, text2)]
159
159
160 if not text2:
160 if not text2:
161 # Just delete some text (speedup).
161 # Just delete some text (speedup).
162 return [(self.DIFF_DELETE, text1)]
162 return [(self.DIFF_DELETE, text1)]
163
163
164 if len(text1) > len(text2):
164 if len(text1) > len(text2):
165 (longtext, shorttext) = (text1, text2)
165 (longtext, shorttext) = (text1, text2)
166 else:
166 else:
167 (shorttext, longtext) = (text1, text2)
167 (shorttext, longtext) = (text1, text2)
168 i = longtext.find(shorttext)
168 i = longtext.find(shorttext)
169 if i != -1:
169 if i != -1:
170 # Shorter text is inside the longer text (speedup).
170 # Shorter text is inside the longer text (speedup).
171 diffs = [(self.DIFF_INSERT, longtext[:i]), (self.DIFF_EQUAL, shorttext),
171 diffs = [(self.DIFF_INSERT, longtext[:i]), (self.DIFF_EQUAL, shorttext),
172 (self.DIFF_INSERT, longtext[i + len(shorttext):])]
172 (self.DIFF_INSERT, longtext[i + len(shorttext):])]
173 # Swap insertions for deletions if diff is reversed.
173 # Swap insertions for deletions if diff is reversed.
174 if len(text1) > len(text2):
174 if len(text1) > len(text2):
175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 return diffs
177 return diffs
178
178
179 if len(shorttext) == 1:
179 if len(shorttext) == 1:
180 # Single character string.
180 # Single character string.
181 # After the previous speedup, the character can't be an equality.
181 # After the previous speedup, the character can't be an equality.
182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183
183
184 # Check to see if the problem can be split in two.
184 # Check to see if the problem can be split in two.
185 hm = self.diff_halfMatch(text1, text2)
185 hm = self.diff_halfMatch(text1, text2)
186 if hm:
186 if hm:
187 # A half-match was found, sort out the return data.
187 # A half-match was found, sort out the return data.
188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 # Send both pairs off for separate processing.
189 # Send both pairs off for separate processing.
190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 # Merge the results.
192 # Merge the results.
193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194
194
195 if checklines and len(text1) > 100 and len(text2) > 100:
195 if checklines and len(text1) > 100 and len(text2) > 100:
196 return self.diff_lineMode(text1, text2, deadline)
196 return self.diff_lineMode(text1, text2, deadline)
197
197
198 return self.diff_bisect(text1, text2, deadline)
198 return self.diff_bisect(text1, text2, deadline)
199
199
200 def diff_lineMode(self, text1, text2, deadline):
200 def diff_lineMode(self, text1, text2, deadline):
201 """Do a quick line-level diff on both strings, then rediff the parts for
201 """Do a quick line-level diff on both strings, then rediff the parts for
202 greater accuracy.
202 greater accuracy.
203 This speedup can produce non-minimal diffs.
203 This speedup can produce non-minimal diffs.
204
204
205 Args:
205 Args:
206 text1: Old string to be diffed.
206 text1: Old string to be diffed.
207 text2: New string to be diffed.
207 text2: New string to be diffed.
208 deadline: Time when the diff should be complete by.
208 deadline: Time when the diff should be complete by.
209
209
210 Returns:
210 Returns:
211 Array of changes.
211 Array of changes.
212 """
212 """
213
213
214 # Scan the text on a line-by-line basis first.
214 # Scan the text on a line-by-line basis first.
215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216
216
217 diffs = self.diff_main(text1, text2, False, deadline)
217 diffs = self.diff_main(text1, text2, False, deadline)
218
218
219 # Convert the diff back to original text.
219 # Convert the diff back to original text.
220 self.diff_charsToLines(diffs, linearray)
220 self.diff_charsToLines(diffs, linearray)
221 # Eliminate freak matches (e.g. blank lines)
221 # Eliminate freak matches (e.g. blank lines)
222 self.diff_cleanupSemantic(diffs)
222 self.diff_cleanupSemantic(diffs)
223
223
224 # Rediff any replacement blocks, this time character-by-character.
224 # Rediff any replacement blocks, this time character-by-character.
225 # Add a dummy entry at the end.
225 # Add a dummy entry at the end.
226 diffs.append((self.DIFF_EQUAL, ''))
226 diffs.append((self.DIFF_EQUAL, ''))
227 pointer = 0
227 pointer = 0
228 count_delete = 0
228 count_delete = 0
229 count_insert = 0
229 count_insert = 0
230 text_delete = ''
230 text_delete = ''
231 text_insert = ''
231 text_insert = ''
232 while pointer < len(diffs):
232 while pointer < len(diffs):
233 if diffs[pointer][0] == self.DIFF_INSERT:
233 if diffs[pointer][0] == self.DIFF_INSERT:
234 count_insert += 1
234 count_insert += 1
235 text_insert += diffs[pointer][1]
235 text_insert += diffs[pointer][1]
236 elif diffs[pointer][0] == self.DIFF_DELETE:
236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 count_delete += 1
237 count_delete += 1
238 text_delete += diffs[pointer][1]
238 text_delete += diffs[pointer][1]
239 elif diffs[pointer][0] == self.DIFF_EQUAL:
239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 # Upon reaching an equality, check for prior redundancies.
240 # Upon reaching an equality, check for prior redundancies.
241 if count_delete >= 1 and count_insert >= 1:
241 if count_delete >= 1 and count_insert >= 1:
242 # Delete the offending records and add the merged ones.
242 # Delete the offending records and add the merged ones.
243 a = self.diff_main(text_delete, text_insert, False, deadline)
243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 diffs[pointer - count_delete - count_insert : pointer] = a
244 diffs[pointer - count_delete - count_insert : pointer] = a
245 pointer = pointer - count_delete - count_insert + len(a)
245 pointer = pointer - count_delete - count_insert + len(a)
246 count_insert = 0
246 count_insert = 0
247 count_delete = 0
247 count_delete = 0
248 text_delete = ''
248 text_delete = ''
249 text_insert = ''
249 text_insert = ''
250
250
251 pointer += 1
251 pointer += 1
252
252
253 diffs.pop() # Remove the dummy entry at the end.
253 diffs.pop() # Remove the dummy entry at the end.
254
254
255 return diffs
255 return diffs
256
256
257 def diff_bisect(self, text1, text2, deadline):
257 def diff_bisect(self, text1, text2, deadline):
258 """Find the 'middle snake' of a diff, split the problem in two
258 """Find the 'middle snake' of a diff, split the problem in two
259 and return the recursively constructed diff.
259 and return the recursively constructed diff.
260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261
261
262 Args:
262 Args:
263 text1: Old string to be diffed.
263 text1: Old string to be diffed.
264 text2: New string to be diffed.
264 text2: New string to be diffed.
265 deadline: Time at which to bail if not yet complete.
265 deadline: Time at which to bail if not yet complete.
266
266
267 Returns:
267 Returns:
268 Array of diff tuples.
268 Array of diff tuples.
269 """
269 """
270
270
271 # Cache the text lengths to prevent multiple calls.
271 # Cache the text lengths to prevent multiple calls.
272 text1_length = len(text1)
272 text1_length = len(text1)
273 text2_length = len(text2)
273 text2_length = len(text2)
274 max_d = (text1_length + text2_length + 1) // 2
274 max_d = (text1_length + text2_length + 1) // 2
275 v_offset = max_d
275 v_offset = max_d
276 v_length = 2 * max_d
276 v_length = 2 * max_d
277 v1 = [-1] * v_length
277 v1 = [-1] * v_length
278 v1[v_offset + 1] = 0
278 v1[v_offset + 1] = 0
279 v2 = v1[:]
279 v2 = v1[:]
280 delta = text1_length - text2_length
280 delta = text1_length - text2_length
281 # If the total number of characters is odd, then the front path will
281 # If the total number of characters is odd, then the front path will
282 # collide with the reverse path.
282 # collide with the reverse path.
283 front = (delta % 2 != 0)
283 front = (delta % 2 != 0)
284 # Offsets for start and end of k loop.
284 # Offsets for start and end of k loop.
285 # Prevents mapping of space beyond the grid.
285 # Prevents mapping of space beyond the grid.
286 k1start = 0
286 k1start = 0
287 k1end = 0
287 k1end = 0
288 k2start = 0
288 k2start = 0
289 k2end = 0
289 k2end = 0
290 for d in xrange(max_d):
290 for d in xrange(max_d):
291 # Bail out if deadline is reached.
291 # Bail out if deadline is reached.
292 if time.time() > deadline:
292 if time.time() > deadline:
293 break
293 break
294
294
295 # Walk the front path one step.
295 # Walk the front path one step.
296 for k1 in xrange(-d + k1start, d + 1 - k1end, 2):
296 for k1 in xrange(-d + k1start, d + 1 - k1end, 2):
297 k1_offset = v_offset + k1
297 k1_offset = v_offset + k1
298 if k1 == -d or (k1 != d and
298 if k1 == -d or (k1 != d and
299 v1[k1_offset - 1] < v1[k1_offset + 1]):
299 v1[k1_offset - 1] < v1[k1_offset + 1]):
300 x1 = v1[k1_offset + 1]
300 x1 = v1[k1_offset + 1]
301 else:
301 else:
302 x1 = v1[k1_offset - 1] + 1
302 x1 = v1[k1_offset - 1] + 1
303 y1 = x1 - k1
303 y1 = x1 - k1
304 while (x1 < text1_length and y1 < text2_length and
304 while (x1 < text1_length and y1 < text2_length and
305 text1[x1] == text2[y1]):
305 text1[x1] == text2[y1]):
306 x1 += 1
306 x1 += 1
307 y1 += 1
307 y1 += 1
308 v1[k1_offset] = x1
308 v1[k1_offset] = x1
309 if x1 > text1_length:
309 if x1 > text1_length:
310 # Ran off the right of the graph.
310 # Ran off the right of the graph.
311 k1end += 2
311 k1end += 2
312 elif y1 > text2_length:
312 elif y1 > text2_length:
313 # Ran off the bottom of the graph.
313 # Ran off the bottom of the graph.
314 k1start += 2
314 k1start += 2
315 elif front:
315 elif front:
316 k2_offset = v_offset + delta - k1
316 k2_offset = v_offset + delta - k1
317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 # Mirror x2 onto top-left coordinate system.
318 # Mirror x2 onto top-left coordinate system.
319 x2 = text1_length - v2[k2_offset]
319 x2 = text1_length - v2[k2_offset]
320 if x1 >= x2:
320 if x1 >= x2:
321 # Overlap detected.
321 # Overlap detected.
322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323
323
324 # Walk the reverse path one step.
324 # Walk the reverse path one step.
325 for k2 in xrange(-d + k2start, d + 1 - k2end, 2):
325 for k2 in xrange(-d + k2start, d + 1 - k2end, 2):
326 k2_offset = v_offset + k2
326 k2_offset = v_offset + k2
327 if k2 == -d or (k2 != d and
327 if k2 == -d or (k2 != d and
328 v2[k2_offset - 1] < v2[k2_offset + 1]):
328 v2[k2_offset - 1] < v2[k2_offset + 1]):
329 x2 = v2[k2_offset + 1]
329 x2 = v2[k2_offset + 1]
330 else:
330 else:
331 x2 = v2[k2_offset - 1] + 1
331 x2 = v2[k2_offset - 1] + 1
332 y2 = x2 - k2
332 y2 = x2 - k2
333 while (x2 < text1_length and y2 < text2_length and
333 while (x2 < text1_length and y2 < text2_length and
334 text1[-x2 - 1] == text2[-y2 - 1]):
334 text1[-x2 - 1] == text2[-y2 - 1]):
335 x2 += 1
335 x2 += 1
336 y2 += 1
336 y2 += 1
337 v2[k2_offset] = x2
337 v2[k2_offset] = x2
338 if x2 > text1_length:
338 if x2 > text1_length:
339 # Ran off the left of the graph.
339 # Ran off the left of the graph.
340 k2end += 2
340 k2end += 2
341 elif y2 > text2_length:
341 elif y2 > text2_length:
342 # Ran off the top of the graph.
342 # Ran off the top of the graph.
343 k2start += 2
343 k2start += 2
344 elif not front:
344 elif not front:
345 k1_offset = v_offset + delta - k2
345 k1_offset = v_offset + delta - k2
346 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
346 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
347 x1 = v1[k1_offset]
347 x1 = v1[k1_offset]
348 y1 = v_offset + x1 - k1_offset
348 y1 = v_offset + x1 - k1_offset
349 # Mirror x2 onto top-left coordinate system.
349 # Mirror x2 onto top-left coordinate system.
350 x2 = text1_length - x2
350 x2 = text1_length - x2
351 if x1 >= x2:
351 if x1 >= x2:
352 # Overlap detected.
352 # Overlap detected.
353 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
353 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
354
354
355 # Diff took too long and hit the deadline or
355 # Diff took too long and hit the deadline or
356 # number of diffs equals number of characters, no commonality at all.
356 # number of diffs equals number of characters, no commonality at all.
357 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
357 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
358
358
359 def diff_bisectSplit(self, text1, text2, x, y, deadline):
359 def diff_bisectSplit(self, text1, text2, x, y, deadline):
360 """Given the location of the 'middle snake', split the diff in two parts
360 """Given the location of the 'middle snake', split the diff in two parts
361 and recurse.
361 and recurse.
362
362
363 Args:
363 Args:
364 text1: Old string to be diffed.
364 text1: Old string to be diffed.
365 text2: New string to be diffed.
365 text2: New string to be diffed.
366 x: Index of split point in text1.
366 x: Index of split point in text1.
367 y: Index of split point in text2.
367 y: Index of split point in text2.
368 deadline: Time at which to bail if not yet complete.
368 deadline: Time at which to bail if not yet complete.
369
369
370 Returns:
370 Returns:
371 Array of diff tuples.
371 Array of diff tuples.
372 """
372 """
373 text1a = text1[:x]
373 text1a = text1[:x]
374 text2a = text2[:y]
374 text2a = text2[:y]
375 text1b = text1[x:]
375 text1b = text1[x:]
376 text2b = text2[y:]
376 text2b = text2[y:]
377
377
378 # Compute both diffs serially.
378 # Compute both diffs serially.
379 diffs = self.diff_main(text1a, text2a, False, deadline)
379 diffs = self.diff_main(text1a, text2a, False, deadline)
380 diffsb = self.diff_main(text1b, text2b, False, deadline)
380 diffsb = self.diff_main(text1b, text2b, False, deadline)
381
381
382 return diffs + diffsb
382 return diffs + diffsb
383
383
384 def diff_linesToChars(self, text1, text2):
384 def diff_linesToChars(self, text1, text2):
385 """Split two texts into an array of strings. Reduce the texts to a string
385 """Split two texts into an array of strings. Reduce the texts to a string
386 of hashes where each Unicode character represents one line.
386 of hashes where each Unicode character represents one line.
387
387
388 Args:
388 Args:
389 text1: First string.
389 text1: First string.
390 text2: Second string.
390 text2: Second string.
391
391
392 Returns:
392 Returns:
393 Three element tuple, containing the encoded text1, the encoded text2 and
393 Three element tuple, containing the encoded text1, the encoded text2 and
394 the array of unique strings. The zeroth element of the array of unique
394 the array of unique strings. The zeroth element of the array of unique
395 strings is intentionally blank.
395 strings is intentionally blank.
396 """
396 """
397 lineArray = [] # e.g. lineArray[4] == "Hello\n"
397 lineArray = [] # e.g. lineArray[4] == "Hello\n"
398 lineHash = {} # e.g. lineHash["Hello\n"] == 4
398 lineHash = {} # e.g. lineHash["Hello\n"] == 4
399
399
400 # "\x00" is a valid character, but various debuggers don't like it.
400 # "\x00" is a valid character, but various debuggers don't like it.
401 # So we'll insert a junk entry to avoid generating a null character.
401 # So we'll insert a junk entry to avoid generating a null character.
402 lineArray.append('')
402 lineArray.append('')
403
403
404 def diff_linesToCharsMunge(text):
404 def diff_linesToCharsMunge(text):
405 """Split a text into an array of strings. Reduce the texts to a string
405 """Split a text into an array of strings. Reduce the texts to a string
406 of hashes where each Unicode character represents one line.
406 of hashes where each Unicode character represents one line.
407 Modifies linearray and linehash through being a closure.
407 Modifies linearray and linehash through being a closure.
408
408
409 Args:
409 Args:
410 text: String to encode.
410 text: String to encode.
411
411
412 Returns:
412 Returns:
413 Encoded string.
413 Encoded string.
414 """
414 """
415 chars = []
415 chars = []
416 # Walk the text, pulling out a substring for each line.
416 # Walk the text, pulling out a substring for each line.
417 # text.split('\n') would would temporarily double our memory footprint.
417 # text.split('\n') would would temporarily double our memory footprint.
418 # Modifying text would create many large strings to garbage collect.
418 # Modifying text would create many large strings to garbage collect.
419 lineStart = 0
419 lineStart = 0
420 lineEnd = -1
420 lineEnd = -1
421 while lineEnd < len(text) - 1:
421 while lineEnd < len(text) - 1:
422 lineEnd = text.find('\n', lineStart)
422 lineEnd = text.find('\n', lineStart)
423 if lineEnd == -1:
423 if lineEnd == -1:
424 lineEnd = len(text) - 1
424 lineEnd = len(text) - 1
425 line = text[lineStart:lineEnd + 1]
425 line = text[lineStart:lineEnd + 1]
426 lineStart = lineEnd + 1
426 lineStart = lineEnd + 1
427
427
428 if line in lineHash:
428 if line in lineHash:
429 chars.append(unichr(lineHash[line]))
429 chars.append(unichr(lineHash[line]))
430 else:
430 else:
431 lineArray.append(line)
431 lineArray.append(line)
432 lineHash[line] = len(lineArray) - 1
432 lineHash[line] = len(lineArray) - 1
433 chars.append(unichr(len(lineArray) - 1))
433 chars.append(unichr(len(lineArray) - 1))
434 return "".join(chars)
434 return "".join(chars)
435
435
436 chars1 = diff_linesToCharsMunge(text1)
436 chars1 = diff_linesToCharsMunge(text1)
437 chars2 = diff_linesToCharsMunge(text2)
437 chars2 = diff_linesToCharsMunge(text2)
438 return (chars1, chars2, lineArray)
438 return (chars1, chars2, lineArray)
439
439
440 def diff_charsToLines(self, diffs, lineArray):
440 def diff_charsToLines(self, diffs, lineArray):
441 """Rehydrate the text in a diff from a string of line hashes to real lines
441 """Rehydrate the text in a diff from a string of line hashes to real lines
442 of text.
442 of text.
443
443
444 Args:
444 Args:
445 diffs: Array of diff tuples.
445 diffs: Array of diff tuples.
446 lineArray: Array of unique strings.
446 lineArray: Array of unique strings.
447 """
447 """
448 for x in xrange(len(diffs)):
448 for x in xrange(len(diffs)):
449 text = []
449 text = []
450 for char in diffs[x][1]:
450 for char in diffs[x][1]:
451 text.append(lineArray[ord(char)])
451 text.append(lineArray[ord(char)])
452 diffs[x] = (diffs[x][0], "".join(text))
452 diffs[x] = (diffs[x][0], "".join(text))
453
453
454 def diff_commonPrefix(self, text1, text2):
454 def diff_commonPrefix(self, text1, text2):
455 """Determine the common prefix of two strings.
455 """Determine the common prefix of two strings.
456
456
457 Args:
457 Args:
458 text1: First string.
458 text1: First string.
459 text2: Second string.
459 text2: Second string.
460
460
461 Returns:
461 Returns:
462 The number of characters common to the start of each string.
462 The number of characters common to the start of each string.
463 """
463 """
464 # Quick check for common null cases.
464 # Quick check for common null cases.
465 if not text1 or not text2 or text1[0] != text2[0]:
465 if not text1 or not text2 or text1[0] != text2[0]:
466 return 0
466 return 0
467 # Binary search.
467 # Binary search.
468 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
468 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
469 pointermin = 0
469 pointermin = 0
470 pointermax = min(len(text1), len(text2))
470 pointermax = min(len(text1), len(text2))
471 pointermid = pointermax
471 pointermid = pointermax
472 pointerstart = 0
472 pointerstart = 0
473 while pointermin < pointermid:
473 while pointermin < pointermid:
474 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
474 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
475 pointermin = pointermid
475 pointermin = pointermid
476 pointerstart = pointermin
476 pointerstart = pointermin
477 else:
477 else:
478 pointermax = pointermid
478 pointermax = pointermid
479 pointermid = (pointermax - pointermin) // 2 + pointermin
479 pointermid = (pointermax - pointermin) // 2 + pointermin
480 return pointermid
480 return pointermid
481
481
482 def diff_commonSuffix(self, text1, text2):
482 def diff_commonSuffix(self, text1, text2):
483 """Determine the common suffix of two strings.
483 """Determine the common suffix of two strings.
484
484
485 Args:
485 Args:
486 text1: First string.
486 text1: First string.
487 text2: Second string.
487 text2: Second string.
488
488
489 Returns:
489 Returns:
490 The number of characters common to the end of each string.
490 The number of characters common to the end of each string.
491 """
491 """
492 # Quick check for common null cases.
492 # Quick check for common null cases.
493 if not text1 or not text2 or text1[-1] != text2[-1]:
493 if not text1 or not text2 or text1[-1] != text2[-1]:
494 return 0
494 return 0
495 # Binary search.
495 # Binary search.
496 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
496 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
497 pointermin = 0
497 pointermin = 0
498 pointermax = min(len(text1), len(text2))
498 pointermax = min(len(text1), len(text2))
499 pointermid = pointermax
499 pointermid = pointermax
500 pointerend = 0
500 pointerend = 0
501 while pointermin < pointermid:
501 while pointermin < pointermid:
502 if (text1[-pointermid:len(text1) - pointerend] ==
502 if (text1[-pointermid:len(text1) - pointerend] ==
503 text2[-pointermid:len(text2) - pointerend]):
503 text2[-pointermid:len(text2) - pointerend]):
504 pointermin = pointermid
504 pointermin = pointermid
505 pointerend = pointermin
505 pointerend = pointermin
506 else:
506 else:
507 pointermax = pointermid
507 pointermax = pointermid
508 pointermid = (pointermax - pointermin) // 2 + pointermin
508 pointermid = (pointermax - pointermin) // 2 + pointermin
509 return pointermid
509 return pointermid
510
510
511 def diff_commonOverlap(self, text1, text2):
511 def diff_commonOverlap(self, text1, text2):
512 """Determine if the suffix of one string is the prefix of another.
512 """Determine if the suffix of one string is the prefix of another.
513
513
514 Args:
514 Args:
515 text1 First string.
515 text1 First string.
516 text2 Second string.
516 text2 Second string.
517
517
518 Returns:
518 Returns:
519 The number of characters common to the end of the first
519 The number of characters common to the end of the first
520 string and the start of the second string.
520 string and the start of the second string.
521 """
521 """
522 # Cache the text lengths to prevent multiple calls.
522 # Cache the text lengths to prevent multiple calls.
523 text1_length = len(text1)
523 text1_length = len(text1)
524 text2_length = len(text2)
524 text2_length = len(text2)
525 # Eliminate the null case.
525 # Eliminate the null case.
526 if text1_length == 0 or text2_length == 0:
526 if text1_length == 0 or text2_length == 0:
527 return 0
527 return 0
528 # Truncate the longer string.
528 # Truncate the longer string.
529 if text1_length > text2_length:
529 if text1_length > text2_length:
530 text1 = text1[-text2_length:]
530 text1 = text1[-text2_length:]
531 elif text1_length < text2_length:
531 elif text1_length < text2_length:
532 text2 = text2[:text1_length]
532 text2 = text2[:text1_length]
533 text_length = min(text1_length, text2_length)
533 text_length = min(text1_length, text2_length)
534 # Quick check for the worst case.
534 # Quick check for the worst case.
535 if text1 == text2:
535 if text1 == text2:
536 return text_length
536 return text_length
537
537
538 # Start by looking for a single character match
538 # Start by looking for a single character match
539 # and increase length until no match is found.
539 # and increase length until no match is found.
540 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
540 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
541 best = 0
541 best = 0
542 length = 1
542 length = 1
543 while True:
543 while True:
544 pattern = text1[-length:]
544 pattern = text1[-length:]
545 found = text2.find(pattern)
545 found = text2.find(pattern)
546 if found == -1:
546 if found == -1:
547 return best
547 return best
548 length += found
548 length += found
549 if found == 0 or text1[-length:] == text2[:length]:
549 if found == 0 or text1[-length:] == text2[:length]:
550 best = length
550 best = length
551 length += 1
551 length += 1
552
552
553 def diff_halfMatch(self, text1, text2):
553 def diff_halfMatch(self, text1, text2):
554 """Do the two texts share a substring which is at least half the length of
554 """Do the two texts share a substring which is at least half the length of
555 the longer text?
555 the longer text?
556 This speedup can produce non-minimal diffs.
556 This speedup can produce non-minimal diffs.
557
557
558 Args:
558 Args:
559 text1: First string.
559 text1: First string.
560 text2: Second string.
560 text2: Second string.
561
561
562 Returns:
562 Returns:
563 Five element Array, containing the prefix of text1, the suffix of text1,
563 Five element Array, containing the prefix of text1, the suffix of text1,
564 the prefix of text2, the suffix of text2 and the common middle. Or None
564 the prefix of text2, the suffix of text2 and the common middle. Or None
565 if there was no match.
565 if there was no match.
566 """
566 """
567 if self.Diff_Timeout <= 0:
567 if self.Diff_Timeout <= 0:
568 # Don't risk returning a non-optimal diff if we have unlimited time.
568 # Don't risk returning a non-optimal diff if we have unlimited time.
569 return None
569 return None
570 if len(text1) > len(text2):
570 if len(text1) > len(text2):
571 (longtext, shorttext) = (text1, text2)
571 (longtext, shorttext) = (text1, text2)
572 else:
572 else:
573 (shorttext, longtext) = (text1, text2)
573 (shorttext, longtext) = (text1, text2)
574 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
574 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
575 return None # Pointless.
575 return None # Pointless.
576
576
577 def diff_halfMatchI(longtext, shorttext, i):
577 def diff_halfMatchI(longtext, shorttext, i):
578 """Does a substring of shorttext exist within longtext such that the
578 """Does a substring of shorttext exist within longtext such that the
579 substring is at least half the length of longtext?
579 substring is at least half the length of longtext?
580 Closure, but does not reference any external variables.
580 Closure, but does not reference any external variables.
581
581
582 Args:
582 Args:
583 longtext: Longer string.
583 longtext: Longer string.
584 shorttext: Shorter string.
584 shorttext: Shorter string.
585 i: Start index of quarter length substring within longtext.
585 i: Start index of quarter length substring within longtext.
586
586
587 Returns:
587 Returns:
588 Five element Array, containing the prefix of longtext, the suffix of
588 Five element Array, containing the prefix of longtext, the suffix of
589 longtext, the prefix of shorttext, the suffix of shorttext and the
589 longtext, the prefix of shorttext, the suffix of shorttext and the
590 common middle. Or None if there was no match.
590 common middle. Or None if there was no match.
591 """
591 """
592 seed = longtext[i:i + len(longtext) // 4]
592 seed = longtext[i:i + len(longtext) // 4]
593 best_common = ''
593 best_common = ''
594 j = shorttext.find(seed)
594 j = shorttext.find(seed)
595 while j != -1:
595 while j != -1:
596 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
596 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
597 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
597 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
598 if len(best_common) < suffixLength + prefixLength:
598 if len(best_common) < suffixLength + prefixLength:
599 best_common = (shorttext[j - suffixLength:j] +
599 best_common = (shorttext[j - suffixLength:j] +
600 shorttext[j:j + prefixLength])
600 shorttext[j:j + prefixLength])
601 best_longtext_a = longtext[:i - suffixLength]
601 best_longtext_a = longtext[:i - suffixLength]
602 best_longtext_b = longtext[i + prefixLength:]
602 best_longtext_b = longtext[i + prefixLength:]
603 best_shorttext_a = shorttext[:j - suffixLength]
603 best_shorttext_a = shorttext[:j - suffixLength]
604 best_shorttext_b = shorttext[j + prefixLength:]
604 best_shorttext_b = shorttext[j + prefixLength:]
605 j = shorttext.find(seed, j + 1)
605 j = shorttext.find(seed, j + 1)
606
606
607 if len(best_common) * 2 >= len(longtext):
607 if len(best_common) * 2 >= len(longtext):
608 return (best_longtext_a, best_longtext_b,
608 return (best_longtext_a, best_longtext_b,
609 best_shorttext_a, best_shorttext_b, best_common)
609 best_shorttext_a, best_shorttext_b, best_common)
610 else:
610 else:
611 return None
611 return None
612
612
613 # First check if the second quarter is the seed for a half-match.
613 # First check if the second quarter is the seed for a half-match.
614 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
614 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
615 # Check again based on the third quarter.
615 # Check again based on the third quarter.
616 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
616 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
617 if not hm1 and not hm2:
617 if not hm1 and not hm2:
618 return None
618 return None
619 elif not hm2:
619 elif not hm2:
620 hm = hm1
620 hm = hm1
621 elif not hm1:
621 elif not hm1:
622 hm = hm2
622 hm = hm2
623 else:
623 else:
624 # Both matched. Select the longest.
624 # Both matched. Select the longest.
625 if len(hm1[4]) > len(hm2[4]):
625 if len(hm1[4]) > len(hm2[4]):
626 hm = hm1
626 hm = hm1
627 else:
627 else:
628 hm = hm2
628 hm = hm2
629
629
630 # A half-match was found, sort out the return data.
630 # A half-match was found, sort out the return data.
631 if len(text1) > len(text2):
631 if len(text1) > len(text2):
632 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
632 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
633 else:
633 else:
634 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
634 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
635 return (text1_a, text1_b, text2_a, text2_b, mid_common)
635 return (text1_a, text1_b, text2_a, text2_b, mid_common)
636
636
637 def diff_cleanupSemantic(self, diffs):
637 def diff_cleanupSemantic(self, diffs):
638 """Reduce the number of edits by eliminating semantically trivial
638 """Reduce the number of edits by eliminating semantically trivial
639 equalities.
639 equalities.
640
640
641 Args:
641 Args:
642 diffs: Array of diff tuples.
642 diffs: Array of diff tuples.
643 """
643 """
644 changes = False
644 changes = False
645 equalities = [] # Stack of indices where equalities are found.
645 equalities = [] # Stack of indices where equalities are found.
646 lastequality = None # Always equal to diffs[equalities[-1]][1]
646 lastequality = None # Always equal to diffs[equalities[-1]][1]
647 pointer = 0 # Index of current position.
647 pointer = 0 # Index of current position.
648 # Number of chars that changed prior to the equality.
648 # Number of chars that changed prior to the equality.
649 length_insertions1, length_deletions1 = 0, 0
649 length_insertions1, length_deletions1 = 0, 0
650 # Number of chars that changed after the equality.
650 # Number of chars that changed after the equality.
651 length_insertions2, length_deletions2 = 0, 0
651 length_insertions2, length_deletions2 = 0, 0
652 while pointer < len(diffs):
652 while pointer < len(diffs):
653 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
653 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
654 equalities.append(pointer)
654 equalities.append(pointer)
655 length_insertions1, length_insertions2 = length_insertions2, 0
655 length_insertions1, length_insertions2 = length_insertions2, 0
656 length_deletions1, length_deletions2 = length_deletions2, 0
656 length_deletions1, length_deletions2 = length_deletions2, 0
657 lastequality = diffs[pointer][1]
657 lastequality = diffs[pointer][1]
658 else: # An insertion or deletion.
658 else: # An insertion or deletion.
659 if diffs[pointer][0] == self.DIFF_INSERT:
659 if diffs[pointer][0] == self.DIFF_INSERT:
660 length_insertions2 += len(diffs[pointer][1])
660 length_insertions2 += len(diffs[pointer][1])
661 else:
661 else:
662 length_deletions2 += len(diffs[pointer][1])
662 length_deletions2 += len(diffs[pointer][1])
663 # Eliminate an equality that is smaller or equal to the edits on both
663 # Eliminate an equality that is smaller or equal to the edits on both
664 # sides of it.
664 # sides of it.
665 if (lastequality and (len(lastequality) <=
665 if (lastequality and (len(lastequality) <=
666 max(length_insertions1, length_deletions1)) and
666 max(length_insertions1, length_deletions1)) and
667 (len(lastequality) <= max(length_insertions2, length_deletions2))):
667 (len(lastequality) <= max(length_insertions2, length_deletions2))):
668 # Duplicate record.
668 # Duplicate record.
669 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
669 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
670 # Change second copy to insert.
670 # Change second copy to insert.
671 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
671 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
672 diffs[equalities[-1] + 1][1])
672 diffs[equalities[-1] + 1][1])
673 # Throw away the equality we just deleted.
673 # Throw away the equality we just deleted.
674 equalities.pop()
674 equalities.pop()
675 # Throw away the previous equality (it needs to be reevaluated).
675 # Throw away the previous equality (it needs to be reevaluated).
676 if len(equalities):
676 if len(equalities):
677 equalities.pop()
677 equalities.pop()
678 if len(equalities):
678 if len(equalities):
679 pointer = equalities[-1]
679 pointer = equalities[-1]
680 else:
680 else:
681 pointer = -1
681 pointer = -1
682 # Reset the counters.
682 # Reset the counters.
683 length_insertions1, length_deletions1 = 0, 0
683 length_insertions1, length_deletions1 = 0, 0
684 length_insertions2, length_deletions2 = 0, 0
684 length_insertions2, length_deletions2 = 0, 0
685 lastequality = None
685 lastequality = None
686 changes = True
686 changes = True
687 pointer += 1
687 pointer += 1
688
688
689 # Normalize the diff.
689 # Normalize the diff.
690 if changes:
690 if changes:
691 self.diff_cleanupMerge(diffs)
691 self.diff_cleanupMerge(diffs)
692 self.diff_cleanupSemanticLossless(diffs)
692 self.diff_cleanupSemanticLossless(diffs)
693
693
694 # Find any overlaps between deletions and insertions.
694 # Find any overlaps between deletions and insertions.
695 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
695 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
696 # -> <del>abc</del>xxx<ins>def</ins>
696 # -> <del>abc</del>xxx<ins>def</ins>
697 # e.g: <del>xxxabc</del><ins>defxxx</ins>
697 # e.g: <del>xxxabc</del><ins>defxxx</ins>
698 # -> <ins>def</ins>xxx<del>abc</del>
698 # -> <ins>def</ins>xxx<del>abc</del>
699 # Only extract an overlap if it is as big as the edit ahead or behind it.
699 # Only extract an overlap if it is as big as the edit ahead or behind it.
700 pointer = 1
700 pointer = 1
701 while pointer < len(diffs):
701 while pointer < len(diffs):
702 if (diffs[pointer - 1][0] == self.DIFF_DELETE and
702 if (diffs[pointer - 1][0] == self.DIFF_DELETE and
703 diffs[pointer][0] == self.DIFF_INSERT):
703 diffs[pointer][0] == self.DIFF_INSERT):
704 deletion = diffs[pointer - 1][1]
704 deletion = diffs[pointer - 1][1]
705 insertion = diffs[pointer][1]
705 insertion = diffs[pointer][1]
706 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
706 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
707 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
707 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
708 if overlap_length1 >= overlap_length2:
708 if overlap_length1 >= overlap_length2:
709 if (overlap_length1 >= len(deletion) / 2.0 or
709 if (overlap_length1 >= len(deletion) / 2.0 or
710 overlap_length1 >= len(insertion) / 2.0):
710 overlap_length1 >= len(insertion) / 2.0):
711 # Overlap found. Insert an equality and trim the surrounding edits.
711 # Overlap found. Insert an equality and trim the surrounding edits.
712 diffs.insert(pointer, (self.DIFF_EQUAL,
712 diffs.insert(pointer, (self.DIFF_EQUAL,
713 insertion[:overlap_length1]))
713 insertion[:overlap_length1]))
714 diffs[pointer - 1] = (self.DIFF_DELETE,
714 diffs[pointer - 1] = (self.DIFF_DELETE,
715 deletion[:len(deletion) - overlap_length1])
715 deletion[:len(deletion) - overlap_length1])
716 diffs[pointer + 1] = (self.DIFF_INSERT,
716 diffs[pointer + 1] = (self.DIFF_INSERT,
717 insertion[overlap_length1:])
717 insertion[overlap_length1:])
718 pointer += 1
718 pointer += 1
719 else:
719 else:
720 if (overlap_length2 >= len(deletion) / 2.0 or
720 if (overlap_length2 >= len(deletion) / 2.0 or
721 overlap_length2 >= len(insertion) / 2.0):
721 overlap_length2 >= len(insertion) / 2.0):
722 # Reverse overlap found.
722 # Reverse overlap found.
723 # Insert an equality and swap and trim the surrounding edits.
723 # Insert an equality and swap and trim the surrounding edits.
724 diffs.insert(pointer, (self.DIFF_EQUAL, deletion[:overlap_length2]))
724 diffs.insert(pointer, (self.DIFF_EQUAL, deletion[:overlap_length2]))
725 diffs[pointer - 1] = (self.DIFF_INSERT,
725 diffs[pointer - 1] = (self.DIFF_INSERT,
726 insertion[:len(insertion) - overlap_length2])
726 insertion[:len(insertion) - overlap_length2])
727 diffs[pointer + 1] = (self.DIFF_DELETE, deletion[overlap_length2:])
727 diffs[pointer + 1] = (self.DIFF_DELETE, deletion[overlap_length2:])
728 pointer += 1
728 pointer += 1
729 pointer += 1
729 pointer += 1
730 pointer += 1
730 pointer += 1
731
731
732 def diff_cleanupSemanticLossless(self, diffs):
732 def diff_cleanupSemanticLossless(self, diffs):
733 """Look for single edits surrounded on both sides by equalities
733 """Look for single edits surrounded on both sides by equalities
734 which can be shifted sideways to align the edit to a word boundary.
734 which can be shifted sideways to align the edit to a word boundary.
735 e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.
735 e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.
736
736
737 Args:
737 Args:
738 diffs: Array of diff tuples.
738 diffs: Array of diff tuples.
739 """
739 """
740
740
741 def diff_cleanupSemanticScore(one, two):
741 def diff_cleanupSemanticScore(one, two):
742 """Given two strings, compute a score representing whether the
742 """Given two strings, compute a score representing whether the
743 internal boundary falls on logical boundaries.
743 internal boundary falls on logical boundaries.
744 Scores range from 6 (best) to 0 (worst).
744 Scores range from 6 (best) to 0 (worst).
745 Closure, but does not reference any external variables.
745 Closure, but does not reference any external variables.
746
746
747 Args:
747 Args:
748 one: First string.
748 one: First string.
749 two: Second string.
749 two: Second string.
750
750
751 Returns:
751 Returns:
752 The score.
752 The score.
753 """
753 """
754 if not one or not two:
754 if not one or not two:
755 # Edges are the best.
755 # Edges are the best.
756 return 6
756 return 6
757
757
758 # Each port of this function behaves slightly differently due to
758 # Each port of this function behaves slightly differently due to
759 # subtle differences in each language's definition of things like
759 # subtle differences in each language's definition of things like
760 # 'whitespace'. Since this function's purpose is largely cosmetic,
760 # 'whitespace'. Since this function's purpose is largely cosmetic,
761 # the choice has been made to use each language's native features
761 # the choice has been made to use each language's native features
762 # rather than force total conformity.
762 # rather than force total conformity.
763 char1 = one[-1]
763 char1 = one[-1]
764 char2 = two[0]
764 char2 = two[0]
765 nonAlphaNumeric1 = not char1.isalnum()
765 nonAlphaNumeric1 = not char1.isalnum()
766 nonAlphaNumeric2 = not char2.isalnum()
766 nonAlphaNumeric2 = not char2.isalnum()
767 whitespace1 = nonAlphaNumeric1 and char1.isspace()
767 whitespace1 = nonAlphaNumeric1 and char1.isspace()
768 whitespace2 = nonAlphaNumeric2 and char2.isspace()
768 whitespace2 = nonAlphaNumeric2 and char2.isspace()
769 lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
769 lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
770 lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
770 lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
771 blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
771 blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
772 blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)
772 blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)
773
773
774 if blankLine1 or blankLine2:
774 if blankLine1 or blankLine2:
775 # Five points for blank lines.
775 # Five points for blank lines.
776 return 5
776 return 5
777 elif lineBreak1 or lineBreak2:
777 elif lineBreak1 or lineBreak2:
778 # Four points for line breaks.
778 # Four points for line breaks.
779 return 4
779 return 4
780 elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
780 elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
781 # Three points for end of sentences.
781 # Three points for end of sentences.
782 return 3
782 return 3
783 elif whitespace1 or whitespace2:
783 elif whitespace1 or whitespace2:
784 # Two points for whitespace.
784 # Two points for whitespace.
785 return 2
785 return 2
786 elif nonAlphaNumeric1 or nonAlphaNumeric2:
786 elif nonAlphaNumeric1 or nonAlphaNumeric2:
787 # One point for non-alphanumeric.
787 # One point for non-alphanumeric.
788 return 1
788 return 1
789 return 0
789 return 0
790
790
791 pointer = 1
791 pointer = 1
792 # Intentionally ignore the first and last element (don't need checking).
792 # Intentionally ignore the first and last element (don't need checking).
793 while pointer < len(diffs) - 1:
793 while pointer < len(diffs) - 1:
794 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
794 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
795 diffs[pointer + 1][0] == self.DIFF_EQUAL):
795 diffs[pointer + 1][0] == self.DIFF_EQUAL):
796 # This is a single edit surrounded by equalities.
796 # This is a single edit surrounded by equalities.
797 equality1 = diffs[pointer - 1][1]
797 equality1 = diffs[pointer - 1][1]
798 edit = diffs[pointer][1]
798 edit = diffs[pointer][1]
799 equality2 = diffs[pointer + 1][1]
799 equality2 = diffs[pointer + 1][1]
800
800
801 # First, shift the edit as far left as possible.
801 # First, shift the edit as far left as possible.
802 commonOffset = self.diff_commonSuffix(equality1, edit)
802 commonOffset = self.diff_commonSuffix(equality1, edit)
803 if commonOffset:
803 if commonOffset:
804 commonString = edit[-commonOffset:]
804 commonString = edit[-commonOffset:]
805 equality1 = equality1[:-commonOffset]
805 equality1 = equality1[:-commonOffset]
806 edit = commonString + edit[:-commonOffset]
806 edit = commonString + edit[:-commonOffset]
807 equality2 = commonString + equality2
807 equality2 = commonString + equality2
808
808
809 # Second, step character by character right, looking for the best fit.
809 # Second, step character by character right, looking for the best fit.
810 bestEquality1 = equality1
810 bestEquality1 = equality1
811 bestEdit = edit
811 bestEdit = edit
812 bestEquality2 = equality2
812 bestEquality2 = equality2
813 bestScore = (diff_cleanupSemanticScore(equality1, edit) +
813 bestScore = (diff_cleanupSemanticScore(equality1, edit) +
814 diff_cleanupSemanticScore(edit, equality2))
814 diff_cleanupSemanticScore(edit, equality2))
815 while edit and equality2 and edit[0] == equality2[0]:
815 while edit and equality2 and edit[0] == equality2[0]:
816 equality1 += edit[0]
816 equality1 += edit[0]
817 edit = edit[1:] + equality2[0]
817 edit = edit[1:] + equality2[0]
818 equality2 = equality2[1:]
818 equality2 = equality2[1:]
819 score = (diff_cleanupSemanticScore(equality1, edit) +
819 score = (diff_cleanupSemanticScore(equality1, edit) +
820 diff_cleanupSemanticScore(edit, equality2))
820 diff_cleanupSemanticScore(edit, equality2))
821 # The >= encourages trailing rather than leading whitespace on edits.
821 # The >= encourages trailing rather than leading whitespace on edits.
822 if score >= bestScore:
822 if score >= bestScore:
823 bestScore = score
823 bestScore = score
824 bestEquality1 = equality1
824 bestEquality1 = equality1
825 bestEdit = edit
825 bestEdit = edit
826 bestEquality2 = equality2
826 bestEquality2 = equality2
827
827
828 if diffs[pointer - 1][1] != bestEquality1:
828 if diffs[pointer - 1][1] != bestEquality1:
829 # We have an improvement, save it back to the diff.
829 # We have an improvement, save it back to the diff.
830 if bestEquality1:
830 if bestEquality1:
831 diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
831 diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
832 else:
832 else:
833 del diffs[pointer - 1]
833 del diffs[pointer - 1]
834 pointer -= 1
834 pointer -= 1
835 diffs[pointer] = (diffs[pointer][0], bestEdit)
835 diffs[pointer] = (diffs[pointer][0], bestEdit)
836 if bestEquality2:
836 if bestEquality2:
837 diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
837 diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
838 else:
838 else:
839 del diffs[pointer + 1]
839 del diffs[pointer + 1]
840 pointer -= 1
840 pointer -= 1
841 pointer += 1
841 pointer += 1
842
842
843 # Define some regex patterns for matching boundaries.
843 # Define some regex patterns for matching boundaries.
844 BLANKLINEEND = re.compile(r"\n\r?\n$");
844 BLANKLINEEND = re.compile(r"\n\r?\n$");
845 BLANKLINESTART = re.compile(r"^\r?\n\r?\n");
845 BLANKLINESTART = re.compile(r"^\r?\n\r?\n");
846
846
847 def diff_cleanupEfficiency(self, diffs):
847 def diff_cleanupEfficiency(self, diffs):
848 """Reduce the number of edits by eliminating operationally trivial
848 """Reduce the number of edits by eliminating operationally trivial
849 equalities.
849 equalities.
850
850
851 Args:
851 Args:
852 diffs: Array of diff tuples.
852 diffs: Array of diff tuples.
853 """
853 """
854 changes = False
854 changes = False
855 equalities = [] # Stack of indices where equalities are found.
855 equalities = [] # Stack of indices where equalities are found.
856 lastequality = None # Always equal to diffs[equalities[-1]][1]
856 lastequality = None # Always equal to diffs[equalities[-1]][1]
857 pointer = 0 # Index of current position.
857 pointer = 0 # Index of current position.
858 pre_ins = False # Is there an insertion operation before the last equality.
858 pre_ins = False # Is there an insertion operation before the last equality.
859 pre_del = False # Is there a deletion operation before the last equality.
859 pre_del = False # Is there a deletion operation before the last equality.
860 post_ins = False # Is there an insertion operation after the last equality.
860 post_ins = False # Is there an insertion operation after the last equality.
861 post_del = False # Is there a deletion operation after the last equality.
861 post_del = False # Is there a deletion operation after the last equality.
862 while pointer < len(diffs):
862 while pointer < len(diffs):
863 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
863 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
864 if (len(diffs[pointer][1]) < self.Diff_EditCost and
864 if (len(diffs[pointer][1]) < self.Diff_EditCost and
865 (post_ins or post_del)):
865 (post_ins or post_del)):
866 # Candidate found.
866 # Candidate found.
867 equalities.append(pointer)
867 equalities.append(pointer)
868 pre_ins = post_ins
868 pre_ins = post_ins
869 pre_del = post_del
869 pre_del = post_del
870 lastequality = diffs[pointer][1]
870 lastequality = diffs[pointer][1]
871 else:
871 else:
872 # Not a candidate, and can never become one.
872 # Not a candidate, and can never become one.
873 equalities = []
873 equalities = []
874 lastequality = None
874 lastequality = None
875
875
876 post_ins = post_del = False
876 post_ins = post_del = False
877 else: # An insertion or deletion.
877 else: # An insertion or deletion.
878 if diffs[pointer][0] == self.DIFF_DELETE:
878 if diffs[pointer][0] == self.DIFF_DELETE:
879 post_del = True
879 post_del = True
880 else:
880 else:
881 post_ins = True
881 post_ins = True
882
882
883 # Five types to be split:
883 # Five types to be split:
884 # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
884 # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
885 # <ins>A</ins>X<ins>C</ins><del>D</del>
885 # <ins>A</ins>X<ins>C</ins><del>D</del>
886 # <ins>A</ins><del>B</del>X<ins>C</ins>
886 # <ins>A</ins><del>B</del>X<ins>C</ins>
887 # <ins>A</del>X<ins>C</ins><del>D</del>
887 # <ins>A</del>X<ins>C</ins><del>D</del>
888 # <ins>A</ins><del>B</del>X<del>C</del>
888 # <ins>A</ins><del>B</del>X<del>C</del>
889
889
890 if lastequality and ((pre_ins and pre_del and post_ins and post_del) or
890 if lastequality and ((pre_ins and pre_del and post_ins and post_del) or
891 ((len(lastequality) < self.Diff_EditCost / 2) and
891 ((len(lastequality) < self.Diff_EditCost / 2) and
892 (pre_ins + pre_del + post_ins + post_del) == 3)):
892 (pre_ins + pre_del + post_ins + post_del) == 3)):
893 # Duplicate record.
893 # Duplicate record.
894 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
894 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
895 # Change second copy to insert.
895 # Change second copy to insert.
896 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
896 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
897 diffs[equalities[-1] + 1][1])
897 diffs[equalities[-1] + 1][1])
898 equalities.pop() # Throw away the equality we just deleted.
898 equalities.pop() # Throw away the equality we just deleted.
899 lastequality = None
899 lastequality = None
900 if pre_ins and pre_del:
900 if pre_ins and pre_del:
901 # No changes made which could affect previous entry, keep going.
901 # No changes made which could affect previous entry, keep going.
902 post_ins = post_del = True
902 post_ins = post_del = True
903 equalities = []
903 equalities = []
904 else:
904 else:
905 if len(equalities):
905 if len(equalities):
906 equalities.pop() # Throw away the previous equality.
906 equalities.pop() # Throw away the previous equality.
907 if len(equalities):
907 if len(equalities):
908 pointer = equalities[-1]
908 pointer = equalities[-1]
909 else:
909 else:
910 pointer = -1
910 pointer = -1
911 post_ins = post_del = False
911 post_ins = post_del = False
912 changes = True
912 changes = True
913 pointer += 1
913 pointer += 1
914
914
915 if changes:
915 if changes:
916 self.diff_cleanupMerge(diffs)
916 self.diff_cleanupMerge(diffs)
917
917
918 def diff_cleanupMerge(self, diffs):
918 def diff_cleanupMerge(self, diffs):
919 """Reorder and merge like edit sections. Merge equalities.
919 """Reorder and merge like edit sections. Merge equalities.
920 Any edit section can move as long as it doesn't cross an equality.
920 Any edit section can move as long as it doesn't cross an equality.
921
921
922 Args:
922 Args:
923 diffs: Array of diff tuples.
923 diffs: Array of diff tuples.
924 """
924 """
925 diffs.append((self.DIFF_EQUAL, '')) # Add a dummy entry at the end.
925 diffs.append((self.DIFF_EQUAL, '')) # Add a dummy entry at the end.
926 pointer = 0
926 pointer = 0
927 count_delete = 0
927 count_delete = 0
928 count_insert = 0
928 count_insert = 0
929 text_delete = ''
929 text_delete = ''
930 text_insert = ''
930 text_insert = ''
931 while pointer < len(diffs):
931 while pointer < len(diffs):
932 if diffs[pointer][0] == self.DIFF_INSERT:
932 if diffs[pointer][0] == self.DIFF_INSERT:
933 count_insert += 1
933 count_insert += 1
934 text_insert += diffs[pointer][1]
934 text_insert += diffs[pointer][1]
935 pointer += 1
935 pointer += 1
936 elif diffs[pointer][0] == self.DIFF_DELETE:
936 elif diffs[pointer][0] == self.DIFF_DELETE:
937 count_delete += 1
937 count_delete += 1
938 text_delete += diffs[pointer][1]
938 text_delete += diffs[pointer][1]
939 pointer += 1
939 pointer += 1
940 elif diffs[pointer][0] == self.DIFF_EQUAL:
940 elif diffs[pointer][0] == self.DIFF_EQUAL:
941 # Upon reaching an equality, check for prior redundancies.
941 # Upon reaching an equality, check for prior redundancies.
942 if count_delete + count_insert > 1:
942 if count_delete + count_insert > 1:
943 if count_delete != 0 and count_insert != 0:
943 if count_delete != 0 and count_insert != 0:
944 # Factor out any common prefixies.
944 # Factor out any common prefixies.
945 commonlength = self.diff_commonPrefix(text_insert, text_delete)
945 commonlength = self.diff_commonPrefix(text_insert, text_delete)
946 if commonlength != 0:
946 if commonlength != 0:
947 x = pointer - count_delete - count_insert - 1
947 x = pointer - count_delete - count_insert - 1
948 if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
948 if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
949 diffs[x] = (diffs[x][0], diffs[x][1] +
949 diffs[x] = (diffs[x][0], diffs[x][1] +
950 text_insert[:commonlength])
950 text_insert[:commonlength])
951 else:
951 else:
952 diffs.insert(0, (self.DIFF_EQUAL, text_insert[:commonlength]))
952 diffs.insert(0, (self.DIFF_EQUAL, text_insert[:commonlength]))
953 pointer += 1
953 pointer += 1
954 text_insert = text_insert[commonlength:]
954 text_insert = text_insert[commonlength:]
955 text_delete = text_delete[commonlength:]
955 text_delete = text_delete[commonlength:]
956 # Factor out any common suffixies.
956 # Factor out any common suffixies.
957 commonlength = self.diff_commonSuffix(text_insert, text_delete)
957 commonlength = self.diff_commonSuffix(text_insert, text_delete)
958 if commonlength != 0:
958 if commonlength != 0:
959 diffs[pointer] = (diffs[pointer][0], text_insert[-commonlength:] +
959 diffs[pointer] = (diffs[pointer][0], text_insert[-commonlength:] +
960 diffs[pointer][1])
960 diffs[pointer][1])
961 text_insert = text_insert[:-commonlength]
961 text_insert = text_insert[:-commonlength]
962 text_delete = text_delete[:-commonlength]
962 text_delete = text_delete[:-commonlength]
963 # Delete the offending records and add the merged ones.
963 # Delete the offending records and add the merged ones.
964 if count_delete == 0:
964 if count_delete == 0:
965 diffs[pointer - count_insert : pointer] = [
965 diffs[pointer - count_insert : pointer] = [
966 (self.DIFF_INSERT, text_insert)]
966 (self.DIFF_INSERT, text_insert)]
967 elif count_insert == 0:
967 elif count_insert == 0:
968 diffs[pointer - count_delete : pointer] = [
968 diffs[pointer - count_delete : pointer] = [
969 (self.DIFF_DELETE, text_delete)]
969 (self.DIFF_DELETE, text_delete)]
970 else:
970 else:
971 diffs[pointer - count_delete - count_insert : pointer] = [
971 diffs[pointer - count_delete - count_insert : pointer] = [
972 (self.DIFF_DELETE, text_delete),
972 (self.DIFF_DELETE, text_delete),
973 (self.DIFF_INSERT, text_insert)]
973 (self.DIFF_INSERT, text_insert)]
974 pointer = pointer - count_delete - count_insert + 1
974 pointer = pointer - count_delete - count_insert + 1
975 if count_delete != 0:
975 if count_delete != 0:
976 pointer += 1
976 pointer += 1
977 if count_insert != 0:
977 if count_insert != 0:
978 pointer += 1
978 pointer += 1
979 elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
979 elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
980 # Merge this equality with the previous one.
980 # Merge this equality with the previous one.
981 diffs[pointer - 1] = (diffs[pointer - 1][0],
981 diffs[pointer - 1] = (diffs[pointer - 1][0],
982 diffs[pointer - 1][1] + diffs[pointer][1])
982 diffs[pointer - 1][1] + diffs[pointer][1])
983 del diffs[pointer]
983 del diffs[pointer]
984 else:
984 else:
985 pointer += 1
985 pointer += 1
986
986
987 count_insert = 0
987 count_insert = 0
988 count_delete = 0
988 count_delete = 0
989 text_delete = ''
989 text_delete = ''
990 text_insert = ''
990 text_insert = ''
991
991
992 if diffs[-1][1] == '':
992 if diffs[-1][1] == '':
993 diffs.pop() # Remove the dummy entry at the end.
993 diffs.pop() # Remove the dummy entry at the end.
994
994
995 # Second pass: look for single edits surrounded on both sides by equalities
995 # Second pass: look for single edits surrounded on both sides by equalities
996 # which can be shifted sideways to eliminate an equality.
996 # which can be shifted sideways to eliminate an equality.
997 # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
997 # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
998 changes = False
998 changes = False
999 pointer = 1
999 pointer = 1
1000 # Intentionally ignore the first and last element (don't need checking).
1000 # Intentionally ignore the first and last element (don't need checking).
1001 while pointer < len(diffs) - 1:
1001 while pointer < len(diffs) - 1:
1002 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
1002 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
1003 diffs[pointer + 1][0] == self.DIFF_EQUAL):
1003 diffs[pointer + 1][0] == self.DIFF_EQUAL):
1004 # This is a single edit surrounded by equalities.
1004 # This is a single edit surrounded by equalities.
1005 if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
1005 if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
1006 # Shift the edit over the previous equality.
1006 # Shift the edit over the previous equality.
1007 diffs[pointer] = (diffs[pointer][0],
1007 diffs[pointer] = (diffs[pointer][0],
1008 diffs[pointer - 1][1] +
1008 diffs[pointer - 1][1] +
1009 diffs[pointer][1][:-len(diffs[pointer - 1][1])])
1009 diffs[pointer][1][:-len(diffs[pointer - 1][1])])
1010 diffs[pointer + 1] = (diffs[pointer + 1][0],
1010 diffs[pointer + 1] = (diffs[pointer + 1][0],
1011 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1011 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1012 del diffs[pointer - 1]
1012 del diffs[pointer - 1]
1013 changes = True
1013 changes = True
1014 elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
1014 elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
1015 # Shift the edit over the next equality.
1015 # Shift the edit over the next equality.
1016 diffs[pointer - 1] = (diffs[pointer - 1][0],
1016 diffs[pointer - 1] = (diffs[pointer - 1][0],
1017 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1017 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1018 diffs[pointer] = (diffs[pointer][0],
1018 diffs[pointer] = (diffs[pointer][0],
1019 diffs[pointer][1][len(diffs[pointer + 1][1]):] +
1019 diffs[pointer][1][len(diffs[pointer + 1][1]):] +
1020 diffs[pointer + 1][1])
1020 diffs[pointer + 1][1])
1021 del diffs[pointer + 1]
1021 del diffs[pointer + 1]
1022 changes = True
1022 changes = True
1023 pointer += 1
1023 pointer += 1
1024
1024
1025 # If shifts were made, the diff needs reordering and another shift sweep.
1025 # If shifts were made, the diff needs reordering and another shift sweep.
1026 if changes:
1026 if changes:
1027 self.diff_cleanupMerge(diffs)
1027 self.diff_cleanupMerge(diffs)
1028
1028
1029 def diff_xIndex(self, diffs, loc):
1029 def diff_xIndex(self, diffs, loc):
1030 """loc is a location in text1, compute and return the equivalent location
1030 """loc is a location in text1, compute and return the equivalent location
1031 in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8
1031 in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8
1032
1032
1033 Args:
1033 Args:
1034 diffs: Array of diff tuples.
1034 diffs: Array of diff tuples.
1035 loc: Location within text1.
1035 loc: Location within text1.
1036
1036
1037 Returns:
1037 Returns:
1038 Location within text2.
1038 Location within text2.
1039 """
1039 """
1040 chars1 = 0
1040 chars1 = 0
1041 chars2 = 0
1041 chars2 = 0
1042 last_chars1 = 0
1042 last_chars1 = 0
1043 last_chars2 = 0
1043 last_chars2 = 0
1044 for x in xrange(len(diffs)):
1044 for x in xrange(len(diffs)):
1045 (op, text) = diffs[x]
1045 (op, text) = diffs[x]
1046 if op != self.DIFF_INSERT: # Equality or deletion.
1046 if op != self.DIFF_INSERT: # Equality or deletion.
1047 chars1 += len(text)
1047 chars1 += len(text)
1048 if op != self.DIFF_DELETE: # Equality or insertion.
1048 if op != self.DIFF_DELETE: # Equality or insertion.
1049 chars2 += len(text)
1049 chars2 += len(text)
1050 if chars1 > loc: # Overshot the location.
1050 if chars1 > loc: # Overshot the location.
1051 break
1051 break
1052 last_chars1 = chars1
1052 last_chars1 = chars1
1053 last_chars2 = chars2
1053 last_chars2 = chars2
1054
1054
1055 if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
1055 if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
1056 # The location was deleted.
1056 # The location was deleted.
1057 return last_chars2
1057 return last_chars2
1058 # Add the remaining len(character).
1058 # Add the remaining len(character).
1059 return last_chars2 + (loc - last_chars1)
1059 return last_chars2 + (loc - last_chars1)
1060
1060
1061 def diff_prettyHtml(self, diffs):
1061 def diff_prettyHtml(self, diffs):
1062 """Convert a diff array into a pretty HTML report.
1062 """Convert a diff array into a pretty HTML report.
1063
1063
1064 Args:
1064 Args:
1065 diffs: Array of diff tuples.
1065 diffs: Array of diff tuples.
1066
1066
1067 Returns:
1067 Returns:
1068 HTML representation.
1068 HTML representation.
1069 """
1069 """
1070 html = []
1070 html = []
1071 for (op, data) in diffs:
1071 for (op, data) in diffs:
1072 text = (data.replace("&", "&amp;").replace("<", "&lt;")
1072 text = (data.replace("&", "&amp;").replace("<", "&lt;")
1073 .replace(">", "&gt;").replace("\n", "&para;<br>"))
1073 .replace(">", "&gt;").replace("\n", "&para;<br>"))
1074 if op == self.DIFF_INSERT:
1074 if op == self.DIFF_INSERT:
1075 html.append("<ins style=\"background:#e6ffe6;\">%s</ins>" % text)
1075 html.append("<ins style=\"background:#e6ffe6;\">%s</ins>" % text)
1076 elif op == self.DIFF_DELETE:
1076 elif op == self.DIFF_DELETE:
1077 html.append("<del style=\"background:#ffe6e6;\">%s</del>" % text)
1077 html.append("<del style=\"background:#ffe6e6;\">%s</del>" % text)
1078 elif op == self.DIFF_EQUAL:
1078 elif op == self.DIFF_EQUAL:
1079 html.append("<span>%s</span>" % text)
1079 html.append("<span>%s</span>" % text)
1080 return "".join(html)
1080 return "".join(html)
1081
1081
1082 def diff_text1(self, diffs):
1082 def diff_text1(self, diffs):
1083 """Compute and return the source text (all equalities and deletions).
1083 """Compute and return the source text (all equalities and deletions).
1084
1084
1085 Args:
1085 Args:
1086 diffs: Array of diff tuples.
1086 diffs: Array of diff tuples.
1087
1087
1088 Returns:
1088 Returns:
1089 Source text.
1089 Source text.
1090 """
1090 """
1091 text = []
1091 text = []
1092 for (op, data) in diffs:
1092 for (op, data) in diffs:
1093 if op != self.DIFF_INSERT:
1093 if op != self.DIFF_INSERT:
1094 text.append(data)
1094 text.append(data)
1095 return "".join(text)
1095 return "".join(text)
1096
1096
1097 def diff_text2(self, diffs):
1097 def diff_text2(self, diffs):
1098 """Compute and return the destination text (all equalities and insertions).
1098 """Compute and return the destination text (all equalities and insertions).
1099
1099
1100 Args:
1100 Args:
1101 diffs: Array of diff tuples.
1101 diffs: Array of diff tuples.
1102
1102
1103 Returns:
1103 Returns:
1104 Destination text.
1104 Destination text.
1105 """
1105 """
1106 text = []
1106 text = []
1107 for (op, data) in diffs:
1107 for (op, data) in diffs:
1108 if op != self.DIFF_DELETE:
1108 if op != self.DIFF_DELETE:
1109 text.append(data)
1109 text.append(data)
1110 return "".join(text)
1110 return "".join(text)
1111
1111
1112 def diff_levenshtein(self, diffs):
1112 def diff_levenshtein(self, diffs):
1113 """Compute the Levenshtein distance; the number of inserted, deleted or
1113 """Compute the Levenshtein distance; the number of inserted, deleted or
1114 substituted characters.
1114 substituted characters.
1115
1115
1116 Args:
1116 Args:
1117 diffs: Array of diff tuples.
1117 diffs: Array of diff tuples.
1118
1118
1119 Returns:
1119 Returns:
1120 Number of changes.
1120 Number of changes.
1121 """
1121 """
1122 levenshtein = 0
1122 levenshtein = 0
1123 insertions = 0
1123 insertions = 0
1124 deletions = 0
1124 deletions = 0
1125 for (op, data) in diffs:
1125 for (op, data) in diffs:
1126 if op == self.DIFF_INSERT:
1126 if op == self.DIFF_INSERT:
1127 insertions += len(data)
1127 insertions += len(data)
1128 elif op == self.DIFF_DELETE:
1128 elif op == self.DIFF_DELETE:
1129 deletions += len(data)
1129 deletions += len(data)
1130 elif op == self.DIFF_EQUAL:
1130 elif op == self.DIFF_EQUAL:
1131 # A deletion and an insertion is one substitution.
1131 # A deletion and an insertion is one substitution.
1132 levenshtein += max(insertions, deletions)
1132 levenshtein += max(insertions, deletions)
1133 insertions = 0
1133 insertions = 0
1134 deletions = 0
1134 deletions = 0
1135 levenshtein += max(insertions, deletions)
1135 levenshtein += max(insertions, deletions)
1136 return levenshtein
1136 return levenshtein
1137
1137
1138 def diff_toDelta(self, diffs):
1138 def diff_toDelta(self, diffs):
1139 """Crush the diff into an encoded string which describes the operations
1139 """Crush the diff into an encoded string which describes the operations
1140 required to transform text1 into text2.
1140 required to transform text1 into text2.
1141 E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
1141 E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
1142 Operations are tab-separated. Inserted text is escaped using %xx notation.
1142 Operations are tab-separated. Inserted text is escaped using %xx notation.
1143
1143
1144 Args:
1144 Args:
1145 diffs: Array of diff tuples.
1145 diffs: Array of diff tuples.
1146
1146
1147 Returns:
1147 Returns:
1148 Delta text.
1148 Delta text.
1149 """
1149 """
1150 text = []
1150 text = []
1151 for (op, data) in diffs:
1151 for (op, data) in diffs:
1152 if op == self.DIFF_INSERT:
1152 if op == self.DIFF_INSERT:
1153 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1153 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1154 data = data.encode("utf-8")
1154 data = data.encode("utf-8")
1155 text.append("+" + urllib.quote(data, "!~*'();/?:@&=+$,# "))
1155 text.append("+" + urllib.quote(data, "!~*'();/?:@&=+$,# "))
1156 elif op == self.DIFF_DELETE:
1156 elif op == self.DIFF_DELETE:
1157 text.append("-%d" % len(data))
1157 text.append("-%d" % len(data))
1158 elif op == self.DIFF_EQUAL:
1158 elif op == self.DIFF_EQUAL:
1159 text.append("=%d" % len(data))
1159 text.append("=%d" % len(data))
1160 return "\t".join(text)
1160 return "\t".join(text)
1161
1161
1162 def diff_fromDelta(self, text1, delta):
1162 def diff_fromDelta(self, text1, delta):
1163 """Given the original text1, and an encoded string which describes the
1163 """Given the original text1, and an encoded string which describes the
1164 operations required to transform text1 into text2, compute the full diff.
1164 operations required to transform text1 into text2, compute the full diff.
1165
1165
1166 Args:
1166 Args:
1167 text1: Source string for the diff.
1167 text1: Source string for the diff.
1168 delta: Delta text.
1168 delta: Delta text.
1169
1169
1170 Returns:
1170 Returns:
1171 Array of diff tuples.
1171 Array of diff tuples.
1172
1172
1173 Raises:
1173 Raises:
1174 ValueError: If invalid input.
1174 ValueError: If invalid input.
1175 """
1175 """
1176 if type(delta) == unicode:
1176 if type(delta) == unicode:
1177 # Deltas should be composed of a subset of ascii chars, Unicode not
1177 # Deltas should be composed of a subset of ascii chars, Unicode not
1178 # required. If this encode raises UnicodeEncodeError, delta is invalid.
1178 # required. If this encode raises UnicodeEncodeError, delta is invalid.
1179 delta = delta.encode("ascii")
1179 delta = delta.encode("ascii")
1180 diffs = []
1180 diffs = []
1181 pointer = 0 # Cursor in text1
1181 pointer = 0 # Cursor in text1
1182 tokens = delta.split("\t")
1182 tokens = delta.split("\t")
1183 for token in tokens:
1183 for token in tokens:
1184 if token == "":
1184 if token == "":
1185 # Blank tokens are ok (from a trailing \t).
1185 # Blank tokens are ok (from a trailing \t).
1186 continue
1186 continue
1187 # Each token begins with a one character parameter which specifies the
1187 # Each token begins with a one character parameter which specifies the
1188 # operation of this token (delete, insert, equality).
1188 # operation of this token (delete, insert, equality).
1189 param = token[1:]
1189 param = token[1:]
1190 if token[0] == "+":
1190 if token[0] == "+":
1191 param = urllib.unquote(param).decode("utf-8")
1191 param = urllib.unquote(param).decode("utf-8")
1192 diffs.append((self.DIFF_INSERT, param))
1192 diffs.append((self.DIFF_INSERT, param))
1193 elif token[0] == "-" or token[0] == "=":
1193 elif token[0] == "-" or token[0] == "=":
1194 try:
1194 try:
1195 n = int(param)
1195 n = int(param)
1196 except ValueError:
1196 except ValueError:
1197 raise ValueError("Invalid number in diff_fromDelta: " + param)
1197 raise ValueError("Invalid number in diff_fromDelta: " + param)
1198 if n < 0:
1198 if n < 0:
1199 raise ValueError("Negative number in diff_fromDelta: " + param)
1199 raise ValueError("Negative number in diff_fromDelta: " + param)
1200 text = text1[pointer : pointer + n]
1200 text = text1[pointer : pointer + n]
1201 pointer += n
1201 pointer += n
1202 if token[0] == "=":
1202 if token[0] == "=":
1203 diffs.append((self.DIFF_EQUAL, text))
1203 diffs.append((self.DIFF_EQUAL, text))
1204 else:
1204 else:
1205 diffs.append((self.DIFF_DELETE, text))
1205 diffs.append((self.DIFF_DELETE, text))
1206 else:
1206 else:
1207 # Anything else is an error.
1207 # Anything else is an error.
1208 raise ValueError("Invalid diff operation in diff_fromDelta: " +
1208 raise ValueError("Invalid diff operation in diff_fromDelta: " +
1209 token[0])
1209 token[0])
1210 if pointer != len(text1):
1210 if pointer != len(text1):
1211 raise ValueError(
1211 raise ValueError(
1212 "Delta length (%d) does not equal source text length (%d)." %
1212 "Delta length (%d) does not equal source text length (%d)." %
1213 (pointer, len(text1)))
1213 (pointer, len(text1)))
1214 return diffs
1214 return diffs
1215
1215
1216 # MATCH FUNCTIONS
1216 # MATCH FUNCTIONS
1217
1217
1218 def match_main(self, text, pattern, loc):
1218 def match_main(self, text, pattern, loc):
1219 """Locate the best instance of 'pattern' in 'text' near 'loc'.
1219 """Locate the best instance of 'pattern' in 'text' near 'loc'.
1220
1220
1221 Args:
1221 Args:
1222 text: The text to search.
1222 text: The text to search.
1223 pattern: The pattern to search for.
1223 pattern: The pattern to search for.
1224 loc: The location to search around.
1224 loc: The location to search around.
1225
1225
1226 Returns:
1226 Returns:
1227 Best match index or -1.
1227 Best match index or -1.
1228 """
1228 """
1229 # Check for null inputs.
1229 # Check for null inputs.
1230 if text == None or pattern == None:
1230 if text is None or pattern is None:
1231 raise ValueError("Null inputs. (match_main)")
1231 raise ValueError("Null inputs. (match_main)")
1232
1232
1233 loc = max(0, min(loc, len(text)))
1233 loc = max(0, min(loc, len(text)))
1234 if text == pattern:
1234 if text == pattern:
1235 # Shortcut (potentially not guaranteed by the algorithm)
1235 # Shortcut (potentially not guaranteed by the algorithm)
1236 return 0
1236 return 0
1237 elif not text:
1237 elif not text:
1238 # Nothing to match.
1238 # Nothing to match.
1239 return -1
1239 return -1
1240 elif text[loc:loc + len(pattern)] == pattern:
1240 elif text[loc:loc + len(pattern)] == pattern:
1241 # Perfect match at the perfect spot! (Includes case of null pattern)
1241 # Perfect match at the perfect spot! (Includes case of null pattern)
1242 return loc
1242 return loc
1243 else:
1243 else:
1244 # Do a fuzzy compare.
1244 # Do a fuzzy compare.
1245 match = self.match_bitap(text, pattern, loc)
1245 match = self.match_bitap(text, pattern, loc)
1246 return match
1246 return match
1247
1247
1248 def match_bitap(self, text, pattern, loc):
1248 def match_bitap(self, text, pattern, loc):
1249 """Locate the best instance of 'pattern' in 'text' near 'loc' using the
1249 """Locate the best instance of 'pattern' in 'text' near 'loc' using the
1250 Bitap algorithm.
1250 Bitap algorithm.
1251
1251
1252 Args:
1252 Args:
1253 text: The text to search.
1253 text: The text to search.
1254 pattern: The pattern to search for.
1254 pattern: The pattern to search for.
1255 loc: The location to search around.
1255 loc: The location to search around.
1256
1256
1257 Returns:
1257 Returns:
1258 Best match index or -1.
1258 Best match index or -1.
1259 """
1259 """
1260 # Python doesn't have a maxint limit, so ignore this check.
1260 # Python doesn't have a maxint limit, so ignore this check.
1261 #if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
1261 #if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
1262 # raise ValueError("Pattern too long for this application.")
1262 # raise ValueError("Pattern too long for this application.")
1263
1263
1264 # Initialise the alphabet.
1264 # Initialise the alphabet.
1265 s = self.match_alphabet(pattern)
1265 s = self.match_alphabet(pattern)
1266
1266
1267 def match_bitapScore(e, x):
1267 def match_bitapScore(e, x):
1268 """Compute and return the score for a match with e errors and x location.
1268 """Compute and return the score for a match with e errors and x location.
1269 Accesses loc and pattern through being a closure.
1269 Accesses loc and pattern through being a closure.
1270
1270
1271 Args:
1271 Args:
1272 e: Number of errors in match.
1272 e: Number of errors in match.
1273 x: Location of match.
1273 x: Location of match.
1274
1274
1275 Returns:
1275 Returns:
1276 Overall score for match (0.0 = good, 1.0 = bad).
1276 Overall score for match (0.0 = good, 1.0 = bad).
1277 """
1277 """
1278 accuracy = float(e) / len(pattern)
1278 accuracy = float(e) / len(pattern)
1279 proximity = abs(loc - x)
1279 proximity = abs(loc - x)
1280 if not self.Match_Distance:
1280 if not self.Match_Distance:
1281 # Dodge divide by zero error.
1281 # Dodge divide by zero error.
1282 return proximity and 1.0 or accuracy
1282 return proximity and 1.0 or accuracy
1283 return accuracy + (proximity / float(self.Match_Distance))
1283 return accuracy + (proximity / float(self.Match_Distance))
1284
1284
1285 # Highest score beyond which we give up.
1285 # Highest score beyond which we give up.
1286 score_threshold = self.Match_Threshold
1286 score_threshold = self.Match_Threshold
1287 # Is there a nearby exact match? (speedup)
1287 # Is there a nearby exact match? (speedup)
1288 best_loc = text.find(pattern, loc)
1288 best_loc = text.find(pattern, loc)
1289 if best_loc != -1:
1289 if best_loc != -1:
1290 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1290 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1291 # What about in the other direction? (speedup)
1291 # What about in the other direction? (speedup)
1292 best_loc = text.rfind(pattern, loc + len(pattern))
1292 best_loc = text.rfind(pattern, loc + len(pattern))
1293 if best_loc != -1:
1293 if best_loc != -1:
1294 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1294 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1295
1295
1296 # Initialise the bit arrays.
1296 # Initialise the bit arrays.
1297 matchmask = 1 << (len(pattern) - 1)
1297 matchmask = 1 << (len(pattern) - 1)
1298 best_loc = -1
1298 best_loc = -1
1299
1299
1300 bin_max = len(pattern) + len(text)
1300 bin_max = len(pattern) + len(text)
1301 # Empty initialization added to appease pychecker.
1301 # Empty initialization added to appease pychecker.
1302 last_rd = None
1302 last_rd = None
1303 for d in xrange(len(pattern)):
1303 for d in xrange(len(pattern)):
1304 # Scan for the best match each iteration allows for one more error.
1304 # Scan for the best match each iteration allows for one more error.
1305 # Run a binary search to determine how far from 'loc' we can stray at
1305 # Run a binary search to determine how far from 'loc' we can stray at
1306 # this error level.
1306 # this error level.
1307 bin_min = 0
1307 bin_min = 0
1308 bin_mid = bin_max
1308 bin_mid = bin_max
1309 while bin_min < bin_mid:
1309 while bin_min < bin_mid:
1310 if match_bitapScore(d, loc + bin_mid) <= score_threshold:
1310 if match_bitapScore(d, loc + bin_mid) <= score_threshold:
1311 bin_min = bin_mid
1311 bin_min = bin_mid
1312 else:
1312 else:
1313 bin_max = bin_mid
1313 bin_max = bin_mid
1314 bin_mid = (bin_max - bin_min) // 2 + bin_min
1314 bin_mid = (bin_max - bin_min) // 2 + bin_min
1315
1315
1316 # Use the result from this iteration as the maximum for the next.
1316 # Use the result from this iteration as the maximum for the next.
1317 bin_max = bin_mid
1317 bin_max = bin_mid
1318 start = max(1, loc - bin_mid + 1)
1318 start = max(1, loc - bin_mid + 1)
1319 finish = min(loc + bin_mid, len(text)) + len(pattern)
1319 finish = min(loc + bin_mid, len(text)) + len(pattern)
1320
1320
1321 rd = [0] * (finish + 2)
1321 rd = [0] * (finish + 2)
1322 rd[finish + 1] = (1 << d) - 1
1322 rd[finish + 1] = (1 << d) - 1
1323 for j in xrange(finish, start - 1, -1):
1323 for j in xrange(finish, start - 1, -1):
1324 if len(text) <= j - 1:
1324 if len(text) <= j - 1:
1325 # Out of range.
1325 # Out of range.
1326 charMatch = 0
1326 charMatch = 0
1327 else:
1327 else:
1328 charMatch = s.get(text[j - 1], 0)
1328 charMatch = s.get(text[j - 1], 0)
1329 if d == 0: # First pass: exact match.
1329 if d == 0: # First pass: exact match.
1330 rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
1330 rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
1331 else: # Subsequent passes: fuzzy match.
1331 else: # Subsequent passes: fuzzy match.
1332 rd[j] = (((rd[j + 1] << 1) | 1) & charMatch) | (
1332 rd[j] = (((rd[j + 1] << 1) | 1) & charMatch) | (
1333 ((last_rd[j + 1] | last_rd[j]) << 1) | 1) | last_rd[j + 1]
1333 ((last_rd[j + 1] | last_rd[j]) << 1) | 1) | last_rd[j + 1]
1334 if rd[j] & matchmask:
1334 if rd[j] & matchmask:
1335 score = match_bitapScore(d, j - 1)
1335 score = match_bitapScore(d, j - 1)
1336 # This match will almost certainly be better than any existing match.
1336 # This match will almost certainly be better than any existing match.
1337 # But check anyway.
1337 # But check anyway.
1338 if score <= score_threshold:
1338 if score <= score_threshold:
1339 # Told you so.
1339 # Told you so.
1340 score_threshold = score
1340 score_threshold = score
1341 best_loc = j - 1
1341 best_loc = j - 1
1342 if best_loc > loc:
1342 if best_loc > loc:
1343 # When passing loc, don't exceed our current distance from loc.
1343 # When passing loc, don't exceed our current distance from loc.
1344 start = max(1, 2 * loc - best_loc)
1344 start = max(1, 2 * loc - best_loc)
1345 else:
1345 else:
1346 # Already passed loc, downhill from here on in.
1346 # Already passed loc, downhill from here on in.
1347 break
1347 break
1348 # No hope for a (better) match at greater error levels.
1348 # No hope for a (better) match at greater error levels.
1349 if match_bitapScore(d + 1, loc) > score_threshold:
1349 if match_bitapScore(d + 1, loc) > score_threshold:
1350 break
1350 break
1351 last_rd = rd
1351 last_rd = rd
1352 return best_loc
1352 return best_loc
1353
1353
1354 def match_alphabet(self, pattern):
1354 def match_alphabet(self, pattern):
1355 """Initialise the alphabet for the Bitap algorithm.
1355 """Initialise the alphabet for the Bitap algorithm.
1356
1356
1357 Args:
1357 Args:
1358 pattern: The text to encode.
1358 pattern: The text to encode.
1359
1359
1360 Returns:
1360 Returns:
1361 Hash of character locations.
1361 Hash of character locations.
1362 """
1362 """
1363 s = {}
1363 s = {}
1364 for char in pattern:
1364 for char in pattern:
1365 s[char] = 0
1365 s[char] = 0
1366 for i in xrange(len(pattern)):
1366 for i in xrange(len(pattern)):
1367 s[pattern[i]] |= 1 << (len(pattern) - i - 1)
1367 s[pattern[i]] |= 1 << (len(pattern) - i - 1)
1368 return s
1368 return s
1369
1369
1370 # PATCH FUNCTIONS
1370 # PATCH FUNCTIONS
1371
1371
1372 def patch_addContext(self, patch, text):
1372 def patch_addContext(self, patch, text):
1373 """Increase the context until it is unique,
1373 """Increase the context until it is unique,
1374 but don't let the pattern expand beyond Match_MaxBits.
1374 but don't let the pattern expand beyond Match_MaxBits.
1375
1375
1376 Args:
1376 Args:
1377 patch: The patch to grow.
1377 patch: The patch to grow.
1378 text: Source text.
1378 text: Source text.
1379 """
1379 """
1380 if len(text) == 0:
1380 if len(text) == 0:
1381 return
1381 return
1382 pattern = text[patch.start2 : patch.start2 + patch.length1]
1382 pattern = text[patch.start2 : patch.start2 + patch.length1]
1383 padding = 0
1383 padding = 0
1384
1384
1385 # Look for the first and last matches of pattern in text. If two different
1385 # Look for the first and last matches of pattern in text. If two different
1386 # matches are found, increase the pattern length.
1386 # matches are found, increase the pattern length.
1387 while (text.find(pattern) != text.rfind(pattern) and (self.Match_MaxBits ==
1387 while (text.find(pattern) != text.rfind(pattern) and (self.Match_MaxBits ==
1388 0 or len(pattern) < self.Match_MaxBits - self.Patch_Margin -
1388 0 or len(pattern) < self.Match_MaxBits - self.Patch_Margin -
1389 self.Patch_Margin)):
1389 self.Patch_Margin)):
1390 padding += self.Patch_Margin
1390 padding += self.Patch_Margin
1391 pattern = text[max(0, patch.start2 - padding) :
1391 pattern = text[max(0, patch.start2 - padding) :
1392 patch.start2 + patch.length1 + padding]
1392 patch.start2 + patch.length1 + padding]
1393 # Add one chunk for good luck.
1393 # Add one chunk for good luck.
1394 padding += self.Patch_Margin
1394 padding += self.Patch_Margin
1395
1395
1396 # Add the prefix.
1396 # Add the prefix.
1397 prefix = text[max(0, patch.start2 - padding) : patch.start2]
1397 prefix = text[max(0, patch.start2 - padding) : patch.start2]
1398 if prefix:
1398 if prefix:
1399 patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
1399 patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
1400 # Add the suffix.
1400 # Add the suffix.
1401 suffix = text[patch.start2 + patch.length1 :
1401 suffix = text[patch.start2 + patch.length1 :
1402 patch.start2 + patch.length1 + padding]
1402 patch.start2 + patch.length1 + padding]
1403 if suffix:
1403 if suffix:
1404 patch.diffs.append((self.DIFF_EQUAL, suffix))
1404 patch.diffs.append((self.DIFF_EQUAL, suffix))
1405
1405
1406 # Roll back the start points.
1406 # Roll back the start points.
1407 patch.start1 -= len(prefix)
1407 patch.start1 -= len(prefix)
1408 patch.start2 -= len(prefix)
1408 patch.start2 -= len(prefix)
1409 # Extend lengths.
1409 # Extend lengths.
1410 patch.length1 += len(prefix) + len(suffix)
1410 patch.length1 += len(prefix) + len(suffix)
1411 patch.length2 += len(prefix) + len(suffix)
1411 patch.length2 += len(prefix) + len(suffix)
1412
1412
1413 def patch_make(self, a, b=None, c=None):
1413 def patch_make(self, a, b=None, c=None):
1414 """Compute a list of patches to turn text1 into text2.
1414 """Compute a list of patches to turn text1 into text2.
1415 Use diffs if provided, otherwise compute it ourselves.
1415 Use diffs if provided, otherwise compute it ourselves.
1416 There are four ways to call this function, depending on what data is
1416 There are four ways to call this function, depending on what data is
1417 available to the caller:
1417 available to the caller:
1418 Method 1:
1418 Method 1:
1419 a = text1, b = text2
1419 a = text1, b = text2
1420 Method 2:
1420 Method 2:
1421 a = diffs
1421 a = diffs
1422 Method 3 (optimal):
1422 Method 3 (optimal):
1423 a = text1, b = diffs
1423 a = text1, b = diffs
1424 Method 4 (deprecated, use method 3):
1424 Method 4 (deprecated, use method 3):
1425 a = text1, b = text2, c = diffs
1425 a = text1, b = text2, c = diffs
1426
1426
1427 Args:
1427 Args:
1428 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1428 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1429 text2 (method 2).
1429 text2 (method 2).
1430 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1430 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1431 text2 (method 3) or undefined (method 2).
1431 text2 (method 3) or undefined (method 2).
1432 c: Array of diff tuples for text1 to text2 (method 4) or
1432 c: Array of diff tuples for text1 to text2 (method 4) or
1433 undefined (methods 1,2,3).
1433 undefined (methods 1,2,3).
1434
1434
1435 Returns:
1435 Returns:
1436 Array of Patch objects.
1436 Array of Patch objects.
1437 """
1437 """
1438 text1 = None
1438 text1 = None
1439 diffs = None
1439 diffs = None
1440 # Note that texts may arrive as 'str' or 'unicode'.
1440 # Note that texts may arrive as 'str' or 'unicode'.
1441 if isinstance(a, basestring) and isinstance(b, basestring) and c is None:
1441 if isinstance(a, basestring) and isinstance(b, basestring) and c is None:
1442 # Method 1: text1, text2
1442 # Method 1: text1, text2
1443 # Compute diffs from text1 and text2.
1443 # Compute diffs from text1 and text2.
1444 text1 = a
1444 text1 = a
1445 diffs = self.diff_main(text1, b, True)
1445 diffs = self.diff_main(text1, b, True)
1446 if len(diffs) > 2:
1446 if len(diffs) > 2:
1447 self.diff_cleanupSemantic(diffs)
1447 self.diff_cleanupSemantic(diffs)
1448 self.diff_cleanupEfficiency(diffs)
1448 self.diff_cleanupEfficiency(diffs)
1449 elif isinstance(a, list) and b is None and c is None:
1449 elif isinstance(a, list) and b is None and c is None:
1450 # Method 2: diffs
1450 # Method 2: diffs
1451 # Compute text1 from diffs.
1451 # Compute text1 from diffs.
1452 diffs = a
1452 diffs = a
1453 text1 = self.diff_text1(diffs)
1453 text1 = self.diff_text1(diffs)
1454 elif isinstance(a, basestring) and isinstance(b, list) and c is None:
1454 elif isinstance(a, basestring) and isinstance(b, list) and c is None:
1455 # Method 3: text1, diffs
1455 # Method 3: text1, diffs
1456 text1 = a
1456 text1 = a
1457 diffs = b
1457 diffs = b
1458 elif (isinstance(a, basestring) and isinstance(b, basestring) and
1458 elif (isinstance(a, basestring) and isinstance(b, basestring) and
1459 isinstance(c, list)):
1459 isinstance(c, list)):
1460 # Method 4: text1, text2, diffs
1460 # Method 4: text1, text2, diffs
1461 # text2 is not used.
1461 # text2 is not used.
1462 text1 = a
1462 text1 = a
1463 diffs = c
1463 diffs = c
1464 else:
1464 else:
1465 raise ValueError("Unknown call format to patch_make.")
1465 raise ValueError("Unknown call format to patch_make.")
1466
1466
1467 if not diffs:
1467 if not diffs:
1468 return [] # Get rid of the None case.
1468 return [] # Get rid of the None case.
1469 patches = []
1469 patches = []
1470 patch = patch_obj()
1470 patch = patch_obj()
1471 char_count1 = 0 # Number of characters into the text1 string.
1471 char_count1 = 0 # Number of characters into the text1 string.
1472 char_count2 = 0 # Number of characters into the text2 string.
1472 char_count2 = 0 # Number of characters into the text2 string.
1473 prepatch_text = text1 # Recreate the patches to determine context info.
1473 prepatch_text = text1 # Recreate the patches to determine context info.
1474 postpatch_text = text1
1474 postpatch_text = text1
1475 for x in xrange(len(diffs)):
1475 for x in xrange(len(diffs)):
1476 (diff_type, diff_text) = diffs[x]
1476 (diff_type, diff_text) = diffs[x]
1477 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1477 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1478 # A new patch starts here.
1478 # A new patch starts here.
1479 patch.start1 = char_count1
1479 patch.start1 = char_count1
1480 patch.start2 = char_count2
1480 patch.start2 = char_count2
1481 if diff_type == self.DIFF_INSERT:
1481 if diff_type == self.DIFF_INSERT:
1482 # Insertion
1482 # Insertion
1483 patch.diffs.append(diffs[x])
1483 patch.diffs.append(diffs[x])
1484 patch.length2 += len(diff_text)
1484 patch.length2 += len(diff_text)
1485 postpatch_text = (postpatch_text[:char_count2] + diff_text +
1485 postpatch_text = (postpatch_text[:char_count2] + diff_text +
1486 postpatch_text[char_count2:])
1486 postpatch_text[char_count2:])
1487 elif diff_type == self.DIFF_DELETE:
1487 elif diff_type == self.DIFF_DELETE:
1488 # Deletion.
1488 # Deletion.
1489 patch.length1 += len(diff_text)
1489 patch.length1 += len(diff_text)
1490 patch.diffs.append(diffs[x])
1490 patch.diffs.append(diffs[x])
1491 postpatch_text = (postpatch_text[:char_count2] +
1491 postpatch_text = (postpatch_text[:char_count2] +
1492 postpatch_text[char_count2 + len(diff_text):])
1492 postpatch_text[char_count2 + len(diff_text):])
1493 elif (diff_type == self.DIFF_EQUAL and
1493 elif (diff_type == self.DIFF_EQUAL and
1494 len(diff_text) <= 2 * self.Patch_Margin and
1494 len(diff_text) <= 2 * self.Patch_Margin and
1495 len(patch.diffs) != 0 and len(diffs) != x + 1):
1495 len(patch.diffs) != 0 and len(diffs) != x + 1):
1496 # Small equality inside a patch.
1496 # Small equality inside a patch.
1497 patch.diffs.append(diffs[x])
1497 patch.diffs.append(diffs[x])
1498 patch.length1 += len(diff_text)
1498 patch.length1 += len(diff_text)
1499 patch.length2 += len(diff_text)
1499 patch.length2 += len(diff_text)
1500
1500
1501 if (diff_type == self.DIFF_EQUAL and
1501 if (diff_type == self.DIFF_EQUAL and
1502 len(diff_text) >= 2 * self.Patch_Margin):
1502 len(diff_text) >= 2 * self.Patch_Margin):
1503 # Time for a new patch.
1503 # Time for a new patch.
1504 if len(patch.diffs) != 0:
1504 if len(patch.diffs) != 0:
1505 self.patch_addContext(patch, prepatch_text)
1505 self.patch_addContext(patch, prepatch_text)
1506 patches.append(patch)
1506 patches.append(patch)
1507 patch = patch_obj()
1507 patch = patch_obj()
1508 # Unlike Unidiff, our patch lists have a rolling context.
1508 # Unlike Unidiff, our patch lists have a rolling context.
1509 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1509 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1510 # Update prepatch text & pos to reflect the application of the
1510 # Update prepatch text & pos to reflect the application of the
1511 # just completed patch.
1511 # just completed patch.
1512 prepatch_text = postpatch_text
1512 prepatch_text = postpatch_text
1513 char_count1 = char_count2
1513 char_count1 = char_count2
1514
1514
1515 # Update the current character count.
1515 # Update the current character count.
1516 if diff_type != self.DIFF_INSERT:
1516 if diff_type != self.DIFF_INSERT:
1517 char_count1 += len(diff_text)
1517 char_count1 += len(diff_text)
1518 if diff_type != self.DIFF_DELETE:
1518 if diff_type != self.DIFF_DELETE:
1519 char_count2 += len(diff_text)
1519 char_count2 += len(diff_text)
1520
1520
1521 # Pick up the leftover patch if not empty.
1521 # Pick up the leftover patch if not empty.
1522 if len(patch.diffs) != 0:
1522 if len(patch.diffs) != 0:
1523 self.patch_addContext(patch, prepatch_text)
1523 self.patch_addContext(patch, prepatch_text)
1524 patches.append(patch)
1524 patches.append(patch)
1525 return patches
1525 return patches
1526
1526
1527 def patch_deepCopy(self, patches):
1527 def patch_deepCopy(self, patches):
1528 """Given an array of patches, return another array that is identical.
1528 """Given an array of patches, return another array that is identical.
1529
1529
1530 Args:
1530 Args:
1531 patches: Array of Patch objects.
1531 patches: Array of Patch objects.
1532
1532
1533 Returns:
1533 Returns:
1534 Array of Patch objects.
1534 Array of Patch objects.
1535 """
1535 """
1536 patchesCopy = []
1536 patchesCopy = []
1537 for patch in patches:
1537 for patch in patches:
1538 patchCopy = patch_obj()
1538 patchCopy = patch_obj()
1539 # No need to deep copy the tuples since they are immutable.
1539 # No need to deep copy the tuples since they are immutable.
1540 patchCopy.diffs = patch.diffs[:]
1540 patchCopy.diffs = patch.diffs[:]
1541 patchCopy.start1 = patch.start1
1541 patchCopy.start1 = patch.start1
1542 patchCopy.start2 = patch.start2
1542 patchCopy.start2 = patch.start2
1543 patchCopy.length1 = patch.length1
1543 patchCopy.length1 = patch.length1
1544 patchCopy.length2 = patch.length2
1544 patchCopy.length2 = patch.length2
1545 patchesCopy.append(patchCopy)
1545 patchesCopy.append(patchCopy)
1546 return patchesCopy
1546 return patchesCopy
1547
1547
1548 def patch_apply(self, patches, text):
1548 def patch_apply(self, patches, text):
1549 """Merge a set of patches onto the text. Return a patched text, as well
1549 """Merge a set of patches onto the text. Return a patched text, as well
1550 as a list of true/false values indicating which patches were applied.
1550 as a list of true/false values indicating which patches were applied.
1551
1551
1552 Args:
1552 Args:
1553 patches: Array of Patch objects.
1553 patches: Array of Patch objects.
1554 text: Old text.
1554 text: Old text.
1555
1555
1556 Returns:
1556 Returns:
1557 Two element Array, containing the new text and an array of boolean values.
1557 Two element Array, containing the new text and an array of boolean values.
1558 """
1558 """
1559 if not patches:
1559 if not patches:
1560 return (text, [])
1560 return (text, [])
1561
1561
1562 # Deep copy the patches so that no changes are made to originals.
1562 # Deep copy the patches so that no changes are made to originals.
1563 patches = self.patch_deepCopy(patches)
1563 patches = self.patch_deepCopy(patches)
1564
1564
1565 nullPadding = self.patch_addPadding(patches)
1565 nullPadding = self.patch_addPadding(patches)
1566 text = nullPadding + text + nullPadding
1566 text = nullPadding + text + nullPadding
1567 self.patch_splitMax(patches)
1567 self.patch_splitMax(patches)
1568
1568
1569 # delta keeps track of the offset between the expected and actual location
1569 # delta keeps track of the offset between the expected and actual location
1570 # of the previous patch. If there are patches expected at positions 10 and
1570 # of the previous patch. If there are patches expected at positions 10 and
1571 # 20, but the first patch was found at 12, delta is 2 and the second patch
1571 # 20, but the first patch was found at 12, delta is 2 and the second patch
1572 # has an effective expected position of 22.
1572 # has an effective expected position of 22.
1573 delta = 0
1573 delta = 0
1574 results = []
1574 results = []
1575 for patch in patches:
1575 for patch in patches:
1576 expected_loc = patch.start2 + delta
1576 expected_loc = patch.start2 + delta
1577 text1 = self.diff_text1(patch.diffs)
1577 text1 = self.diff_text1(patch.diffs)
1578 end_loc = -1
1578 end_loc = -1
1579 if len(text1) > self.Match_MaxBits:
1579 if len(text1) > self.Match_MaxBits:
1580 # patch_splitMax will only provide an oversized pattern in the case of
1580 # patch_splitMax will only provide an oversized pattern in the case of
1581 # a monster delete.
1581 # a monster delete.
1582 start_loc = self.match_main(text, text1[:self.Match_MaxBits],
1582 start_loc = self.match_main(text, text1[:self.Match_MaxBits],
1583 expected_loc)
1583 expected_loc)
1584 if start_loc != -1:
1584 if start_loc != -1:
1585 end_loc = self.match_main(text, text1[-self.Match_MaxBits:],
1585 end_loc = self.match_main(text, text1[-self.Match_MaxBits:],
1586 expected_loc + len(text1) - self.Match_MaxBits)
1586 expected_loc + len(text1) - self.Match_MaxBits)
1587 if end_loc == -1 or start_loc >= end_loc:
1587 if end_loc == -1 or start_loc >= end_loc:
1588 # Can't find valid trailing context. Drop this patch.
1588 # Can't find valid trailing context. Drop this patch.
1589 start_loc = -1
1589 start_loc = -1
1590 else:
1590 else:
1591 start_loc = self.match_main(text, text1, expected_loc)
1591 start_loc = self.match_main(text, text1, expected_loc)
1592 if start_loc == -1:
1592 if start_loc == -1:
1593 # No match found. :(
1593 # No match found. :(
1594 results.append(False)
1594 results.append(False)
1595 # Subtract the delta for this failed patch from subsequent patches.
1595 # Subtract the delta for this failed patch from subsequent patches.
1596 delta -= patch.length2 - patch.length1
1596 delta -= patch.length2 - patch.length1
1597 else:
1597 else:
1598 # Found a match. :)
1598 # Found a match. :)
1599 results.append(True)
1599 results.append(True)
1600 delta = start_loc - expected_loc
1600 delta = start_loc - expected_loc
1601 if end_loc == -1:
1601 if end_loc == -1:
1602 text2 = text[start_loc : start_loc + len(text1)]
1602 text2 = text[start_loc : start_loc + len(text1)]
1603 else:
1603 else:
1604 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1604 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1605 if text1 == text2:
1605 if text1 == text2:
1606 # Perfect match, just shove the replacement text in.
1606 # Perfect match, just shove the replacement text in.
1607 text = (text[:start_loc] + self.diff_text2(patch.diffs) +
1607 text = (text[:start_loc] + self.diff_text2(patch.diffs) +
1608 text[start_loc + len(text1):])
1608 text[start_loc + len(text1):])
1609 else:
1609 else:
1610 # Imperfect match.
1610 # Imperfect match.
1611 # Run a diff to get a framework of equivalent indices.
1611 # Run a diff to get a framework of equivalent indices.
1612 diffs = self.diff_main(text1, text2, False)
1612 diffs = self.diff_main(text1, text2, False)
1613 if (len(text1) > self.Match_MaxBits and
1613 if (len(text1) > self.Match_MaxBits and
1614 self.diff_levenshtein(diffs) / float(len(text1)) >
1614 self.diff_levenshtein(diffs) / float(len(text1)) >
1615 self.Patch_DeleteThreshold):
1615 self.Patch_DeleteThreshold):
1616 # The end points match, but the content is unacceptably bad.
1616 # The end points match, but the content is unacceptably bad.
1617 results[-1] = False
1617 results[-1] = False
1618 else:
1618 else:
1619 self.diff_cleanupSemanticLossless(diffs)
1619 self.diff_cleanupSemanticLossless(diffs)
1620 index1 = 0
1620 index1 = 0
1621 for (op, data) in patch.diffs:
1621 for (op, data) in patch.diffs:
1622 if op != self.DIFF_EQUAL:
1622 if op != self.DIFF_EQUAL:
1623 index2 = self.diff_xIndex(diffs, index1)
1623 index2 = self.diff_xIndex(diffs, index1)
1624 if op == self.DIFF_INSERT: # Insertion
1624 if op == self.DIFF_INSERT: # Insertion
1625 text = text[:start_loc + index2] + data + text[start_loc +
1625 text = text[:start_loc + index2] + data + text[start_loc +
1626 index2:]
1626 index2:]
1627 elif op == self.DIFF_DELETE: # Deletion
1627 elif op == self.DIFF_DELETE: # Deletion
1628 text = text[:start_loc + index2] + text[start_loc +
1628 text = text[:start_loc + index2] + text[start_loc +
1629 self.diff_xIndex(diffs, index1 + len(data)):]
1629 self.diff_xIndex(diffs, index1 + len(data)):]
1630 if op != self.DIFF_DELETE:
1630 if op != self.DIFF_DELETE:
1631 index1 += len(data)
1631 index1 += len(data)
1632 # Strip the padding off.
1632 # Strip the padding off.
1633 text = text[len(nullPadding):-len(nullPadding)]
1633 text = text[len(nullPadding):-len(nullPadding)]
1634 return (text, results)
1634 return (text, results)
1635
1635
1636 def patch_addPadding(self, patches):
1636 def patch_addPadding(self, patches):
1637 """Add some padding on text start and end so that edges can match
1637 """Add some padding on text start and end so that edges can match
1638 something. Intended to be called only from within patch_apply.
1638 something. Intended to be called only from within patch_apply.
1639
1639
1640 Args:
1640 Args:
1641 patches: Array of Patch objects.
1641 patches: Array of Patch objects.
1642
1642
1643 Returns:
1643 Returns:
1644 The padding string added to each side.
1644 The padding string added to each side.
1645 """
1645 """
1646 paddingLength = self.Patch_Margin
1646 paddingLength = self.Patch_Margin
1647 nullPadding = ""
1647 nullPadding = ""
1648 for x in xrange(1, paddingLength + 1):
1648 for x in xrange(1, paddingLength + 1):
1649 nullPadding += chr(x)
1649 nullPadding += chr(x)
1650
1650
1651 # Bump all the patches forward.
1651 # Bump all the patches forward.
1652 for patch in patches:
1652 for patch in patches:
1653 patch.start1 += paddingLength
1653 patch.start1 += paddingLength
1654 patch.start2 += paddingLength
1654 patch.start2 += paddingLength
1655
1655
1656 # Add some padding on start of first diff.
1656 # Add some padding on start of first diff.
1657 patch = patches[0]
1657 patch = patches[0]
1658 diffs = patch.diffs
1658 diffs = patch.diffs
1659 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1659 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1660 # Add nullPadding equality.
1660 # Add nullPadding equality.
1661 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1661 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1662 patch.start1 -= paddingLength # Should be 0.
1662 patch.start1 -= paddingLength # Should be 0.
1663 patch.start2 -= paddingLength # Should be 0.
1663 patch.start2 -= paddingLength # Should be 0.
1664 patch.length1 += paddingLength
1664 patch.length1 += paddingLength
1665 patch.length2 += paddingLength
1665 patch.length2 += paddingLength
1666 elif paddingLength > len(diffs[0][1]):
1666 elif paddingLength > len(diffs[0][1]):
1667 # Grow first equality.
1667 # Grow first equality.
1668 extraLength = paddingLength - len(diffs[0][1])
1668 extraLength = paddingLength - len(diffs[0][1])
1669 newText = nullPadding[len(diffs[0][1]):] + diffs[0][1]
1669 newText = nullPadding[len(diffs[0][1]):] + diffs[0][1]
1670 diffs[0] = (diffs[0][0], newText)
1670 diffs[0] = (diffs[0][0], newText)
1671 patch.start1 -= extraLength
1671 patch.start1 -= extraLength
1672 patch.start2 -= extraLength
1672 patch.start2 -= extraLength
1673 patch.length1 += extraLength
1673 patch.length1 += extraLength
1674 patch.length2 += extraLength
1674 patch.length2 += extraLength
1675
1675
1676 # Add some padding on end of last diff.
1676 # Add some padding on end of last diff.
1677 patch = patches[-1]
1677 patch = patches[-1]
1678 diffs = patch.diffs
1678 diffs = patch.diffs
1679 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1679 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1680 # Add nullPadding equality.
1680 # Add nullPadding equality.
1681 diffs.append((self.DIFF_EQUAL, nullPadding))
1681 diffs.append((self.DIFF_EQUAL, nullPadding))
1682 patch.length1 += paddingLength
1682 patch.length1 += paddingLength
1683 patch.length2 += paddingLength
1683 patch.length2 += paddingLength
1684 elif paddingLength > len(diffs[-1][1]):
1684 elif paddingLength > len(diffs[-1][1]):
1685 # Grow last equality.
1685 # Grow last equality.
1686 extraLength = paddingLength - len(diffs[-1][1])
1686 extraLength = paddingLength - len(diffs[-1][1])
1687 newText = diffs[-1][1] + nullPadding[:extraLength]
1687 newText = diffs[-1][1] + nullPadding[:extraLength]
1688 diffs[-1] = (diffs[-1][0], newText)
1688 diffs[-1] = (diffs[-1][0], newText)
1689 patch.length1 += extraLength
1689 patch.length1 += extraLength
1690 patch.length2 += extraLength
1690 patch.length2 += extraLength
1691
1691
1692 return nullPadding
1692 return nullPadding
1693
1693
1694 def patch_splitMax(self, patches):
1694 def patch_splitMax(self, patches):
1695 """Look through the patches and break up any which are longer than the
1695 """Look through the patches and break up any which are longer than the
1696 maximum limit of the match algorithm.
1696 maximum limit of the match algorithm.
1697 Intended to be called only from within patch_apply.
1697 Intended to be called only from within patch_apply.
1698
1698
1699 Args:
1699 Args:
1700 patches: Array of Patch objects.
1700 patches: Array of Patch objects.
1701 """
1701 """
1702 patch_size = self.Match_MaxBits
1702 patch_size = self.Match_MaxBits
1703 if patch_size == 0:
1703 if patch_size == 0:
1704 # Python has the option of not splitting strings due to its ability
1704 # Python has the option of not splitting strings due to its ability
1705 # to handle integers of arbitrary precision.
1705 # to handle integers of arbitrary precision.
1706 return
1706 return
1707 for x in xrange(len(patches)):
1707 for x in xrange(len(patches)):
1708 if patches[x].length1 <= patch_size:
1708 if patches[x].length1 <= patch_size:
1709 continue
1709 continue
1710 bigpatch = patches[x]
1710 bigpatch = patches[x]
1711 # Remove the big old patch.
1711 # Remove the big old patch.
1712 del patches[x]
1712 del patches[x]
1713 x -= 1
1713 x -= 1
1714 start1 = bigpatch.start1
1714 start1 = bigpatch.start1
1715 start2 = bigpatch.start2
1715 start2 = bigpatch.start2
1716 precontext = ''
1716 precontext = ''
1717 while len(bigpatch.diffs) != 0:
1717 while len(bigpatch.diffs) != 0:
1718 # Create one of several smaller patches.
1718 # Create one of several smaller patches.
1719 patch = patch_obj()
1719 patch = patch_obj()
1720 empty = True
1720 empty = True
1721 patch.start1 = start1 - len(precontext)
1721 patch.start1 = start1 - len(precontext)
1722 patch.start2 = start2 - len(precontext)
1722 patch.start2 = start2 - len(precontext)
1723 if precontext:
1723 if precontext:
1724 patch.length1 = patch.length2 = len(precontext)
1724 patch.length1 = patch.length2 = len(precontext)
1725 patch.diffs.append((self.DIFF_EQUAL, precontext))
1725 patch.diffs.append((self.DIFF_EQUAL, precontext))
1726
1726
1727 while (len(bigpatch.diffs) != 0 and
1727 while (len(bigpatch.diffs) != 0 and
1728 patch.length1 < patch_size - self.Patch_Margin):
1728 patch.length1 < patch_size - self.Patch_Margin):
1729 (diff_type, diff_text) = bigpatch.diffs[0]
1729 (diff_type, diff_text) = bigpatch.diffs[0]
1730 if diff_type == self.DIFF_INSERT:
1730 if diff_type == self.DIFF_INSERT:
1731 # Insertions are harmless.
1731 # Insertions are harmless.
1732 patch.length2 += len(diff_text)
1732 patch.length2 += len(diff_text)
1733 start2 += len(diff_text)
1733 start2 += len(diff_text)
1734 patch.diffs.append(bigpatch.diffs.pop(0))
1734 patch.diffs.append(bigpatch.diffs.pop(0))
1735 empty = False
1735 empty = False
1736 elif (diff_type == self.DIFF_DELETE and len(patch.diffs) == 1 and
1736 elif (diff_type == self.DIFF_DELETE and len(patch.diffs) == 1 and
1737 patch.diffs[0][0] == self.DIFF_EQUAL and
1737 patch.diffs[0][0] == self.DIFF_EQUAL and
1738 len(diff_text) > 2 * patch_size):
1738 len(diff_text) > 2 * patch_size):
1739 # This is a large deletion. Let it pass in one chunk.
1739 # This is a large deletion. Let it pass in one chunk.
1740 patch.length1 += len(diff_text)
1740 patch.length1 += len(diff_text)
1741 start1 += len(diff_text)
1741 start1 += len(diff_text)
1742 empty = False
1742 empty = False
1743 patch.diffs.append((diff_type, diff_text))
1743 patch.diffs.append((diff_type, diff_text))
1744 del bigpatch.diffs[0]
1744 del bigpatch.diffs[0]
1745 else:
1745 else:
1746 # Deletion or equality. Only take as much as we can stomach.
1746 # Deletion or equality. Only take as much as we can stomach.
1747 diff_text = diff_text[:patch_size - patch.length1 -
1747 diff_text = diff_text[:patch_size - patch.length1 -
1748 self.Patch_Margin]
1748 self.Patch_Margin]
1749 patch.length1 += len(diff_text)
1749 patch.length1 += len(diff_text)
1750 start1 += len(diff_text)
1750 start1 += len(diff_text)
1751 if diff_type == self.DIFF_EQUAL:
1751 if diff_type == self.DIFF_EQUAL:
1752 patch.length2 += len(diff_text)
1752 patch.length2 += len(diff_text)
1753 start2 += len(diff_text)
1753 start2 += len(diff_text)
1754 else:
1754 else:
1755 empty = False
1755 empty = False
1756
1756
1757 patch.diffs.append((diff_type, diff_text))
1757 patch.diffs.append((diff_type, diff_text))
1758 if diff_text == bigpatch.diffs[0][1]:
1758 if diff_text == bigpatch.diffs[0][1]:
1759 del bigpatch.diffs[0]
1759 del bigpatch.diffs[0]
1760 else:
1760 else:
1761 bigpatch.diffs[0] = (bigpatch.diffs[0][0],
1761 bigpatch.diffs[0] = (bigpatch.diffs[0][0],
1762 bigpatch.diffs[0][1][len(diff_text):])
1762 bigpatch.diffs[0][1][len(diff_text):])
1763
1763
1764 # Compute the head context for the next patch.
1764 # Compute the head context for the next patch.
1765 precontext = self.diff_text2(patch.diffs)
1765 precontext = self.diff_text2(patch.diffs)
1766 precontext = precontext[-self.Patch_Margin:]
1766 precontext = precontext[-self.Patch_Margin:]
1767 # Append the end context for this patch.
1767 # Append the end context for this patch.
1768 postcontext = self.diff_text1(bigpatch.diffs)[:self.Patch_Margin]
1768 postcontext = self.diff_text1(bigpatch.diffs)[:self.Patch_Margin]
1769 if postcontext:
1769 if postcontext:
1770 patch.length1 += len(postcontext)
1770 patch.length1 += len(postcontext)
1771 patch.length2 += len(postcontext)
1771 patch.length2 += len(postcontext)
1772 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1772 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1773 patch.diffs[-1] = (self.DIFF_EQUAL, patch.diffs[-1][1] +
1773 patch.diffs[-1] = (self.DIFF_EQUAL, patch.diffs[-1][1] +
1774 postcontext)
1774 postcontext)
1775 else:
1775 else:
1776 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1776 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1777
1777
1778 if not empty:
1778 if not empty:
1779 x += 1
1779 x += 1
1780 patches.insert(x, patch)
1780 patches.insert(x, patch)
1781
1781
1782 def patch_toText(self, patches):
1782 def patch_toText(self, patches):
1783 """Take a list of patches and return a textual representation.
1783 """Take a list of patches and return a textual representation.
1784
1784
1785 Args:
1785 Args:
1786 patches: Array of Patch objects.
1786 patches: Array of Patch objects.
1787
1787
1788 Returns:
1788 Returns:
1789 Text representation of patches.
1789 Text representation of patches.
1790 """
1790 """
1791 text = []
1791 text = []
1792 for patch in patches:
1792 for patch in patches:
1793 text.append(str(patch))
1793 text.append(str(patch))
1794 return "".join(text)
1794 return "".join(text)
1795
1795
1796 def patch_fromText(self, textline):
1796 def patch_fromText(self, textline):
1797 """Parse a textual representation of patches and return a list of patch
1797 """Parse a textual representation of patches and return a list of patch
1798 objects.
1798 objects.
1799
1799
1800 Args:
1800 Args:
1801 textline: Text representation of patches.
1801 textline: Text representation of patches.
1802
1802
1803 Returns:
1803 Returns:
1804 Array of Patch objects.
1804 Array of Patch objects.
1805
1805
1806 Raises:
1806 Raises:
1807 ValueError: If invalid input.
1807 ValueError: If invalid input.
1808 """
1808 """
1809 if type(textline) == unicode:
1809 if type(textline) == unicode:
1810 # Patches should be composed of a subset of ascii chars, Unicode not
1810 # Patches should be composed of a subset of ascii chars, Unicode not
1811 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1811 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1812 textline = textline.encode("ascii")
1812 textline = textline.encode("ascii")
1813 patches = []
1813 patches = []
1814 if not textline:
1814 if not textline:
1815 return patches
1815 return patches
1816 text = textline.split('\n')
1816 text = textline.split('\n')
1817 while len(text) != 0:
1817 while len(text) != 0:
1818 m = re.match("^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1818 m = re.match("^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1819 if not m:
1819 if not m:
1820 raise ValueError("Invalid patch string: " + text[0])
1820 raise ValueError("Invalid patch string: " + text[0])
1821 patch = patch_obj()
1821 patch = patch_obj()
1822 patches.append(patch)
1822 patches.append(patch)
1823 patch.start1 = int(m.group(1))
1823 patch.start1 = int(m.group(1))
1824 if m.group(2) == '':
1824 if m.group(2) == '':
1825 patch.start1 -= 1
1825 patch.start1 -= 1
1826 patch.length1 = 1
1826 patch.length1 = 1
1827 elif m.group(2) == '0':
1827 elif m.group(2) == '0':
1828 patch.length1 = 0
1828 patch.length1 = 0
1829 else:
1829 else:
1830 patch.start1 -= 1
1830 patch.start1 -= 1
1831 patch.length1 = int(m.group(2))
1831 patch.length1 = int(m.group(2))
1832
1832
1833 patch.start2 = int(m.group(3))
1833 patch.start2 = int(m.group(3))
1834 if m.group(4) == '':
1834 if m.group(4) == '':
1835 patch.start2 -= 1
1835 patch.start2 -= 1
1836 patch.length2 = 1
1836 patch.length2 = 1
1837 elif m.group(4) == '0':
1837 elif m.group(4) == '0':
1838 patch.length2 = 0
1838 patch.length2 = 0
1839 else:
1839 else:
1840 patch.start2 -= 1
1840 patch.start2 -= 1
1841 patch.length2 = int(m.group(4))
1841 patch.length2 = int(m.group(4))
1842
1842
1843 del text[0]
1843 del text[0]
1844
1844
1845 while len(text) != 0:
1845 while len(text) != 0:
1846 if text[0]:
1846 if text[0]:
1847 sign = text[0][0]
1847 sign = text[0][0]
1848 else:
1848 else:
1849 sign = ''
1849 sign = ''
1850 line = urllib.unquote(text[0][1:])
1850 line = urllib.unquote(text[0][1:])
1851 line = line.decode("utf-8")
1851 line = line.decode("utf-8")
1852 if sign == '+':
1852 if sign == '+':
1853 # Insertion.
1853 # Insertion.
1854 patch.diffs.append((self.DIFF_INSERT, line))
1854 patch.diffs.append((self.DIFF_INSERT, line))
1855 elif sign == '-':
1855 elif sign == '-':
1856 # Deletion.
1856 # Deletion.
1857 patch.diffs.append((self.DIFF_DELETE, line))
1857 patch.diffs.append((self.DIFF_DELETE, line))
1858 elif sign == ' ':
1858 elif sign == ' ':
1859 # Minor equality.
1859 # Minor equality.
1860 patch.diffs.append((self.DIFF_EQUAL, line))
1860 patch.diffs.append((self.DIFF_EQUAL, line))
1861 elif sign == '@':
1861 elif sign == '@':
1862 # Start of next patch.
1862 # Start of next patch.
1863 break
1863 break
1864 elif sign == '':
1864 elif sign == '':
1865 # Blank line? Whatever.
1865 # Blank line? Whatever.
1866 pass
1866 pass
1867 else:
1867 else:
1868 # WTF?
1868 # WTF?
1869 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1869 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1870 del text[0]
1870 del text[0]
1871 return patches
1871 return patches
1872
1872
1873
1873
1874 class patch_obj:
1874 class patch_obj:
1875 """Class representing one patch operation.
1875 """Class representing one patch operation.
1876 """
1876 """
1877
1877
1878 def __init__(self):
1878 def __init__(self):
1879 """Initializes with an empty list of diffs.
1879 """Initializes with an empty list of diffs.
1880 """
1880 """
1881 self.diffs = []
1881 self.diffs = []
1882 self.start1 = None
1882 self.start1 = None
1883 self.start2 = None
1883 self.start2 = None
1884 self.length1 = 0
1884 self.length1 = 0
1885 self.length2 = 0
1885 self.length2 = 0
1886
1886
1887 def __str__(self):
1887 def __str__(self):
1888 """Emmulate GNU diff's format.
1888 """Emmulate GNU diff's format.
1889 Header: @@ -382,8 +481,9 @@
1889 Header: @@ -382,8 +481,9 @@
1890 Indicies are printed as 1-based, not 0-based.
1890 Indicies are printed as 1-based, not 0-based.
1891
1891
1892 Returns:
1892 Returns:
1893 The GNU diff string.
1893 The GNU diff string.
1894 """
1894 """
1895 if self.length1 == 0:
1895 if self.length1 == 0:
1896 coords1 = str(self.start1) + ",0"
1896 coords1 = str(self.start1) + ",0"
1897 elif self.length1 == 1:
1897 elif self.length1 == 1:
1898 coords1 = str(self.start1 + 1)
1898 coords1 = str(self.start1 + 1)
1899 else:
1899 else:
1900 coords1 = str(self.start1 + 1) + "," + str(self.length1)
1900 coords1 = str(self.start1 + 1) + "," + str(self.length1)
1901 if self.length2 == 0:
1901 if self.length2 == 0:
1902 coords2 = str(self.start2) + ",0"
1902 coords2 = str(self.start2) + ",0"
1903 elif self.length2 == 1:
1903 elif self.length2 == 1:
1904 coords2 = str(self.start2 + 1)
1904 coords2 = str(self.start2 + 1)
1905 else:
1905 else:
1906 coords2 = str(self.start2 + 1) + "," + str(self.length2)
1906 coords2 = str(self.start2 + 1) + "," + str(self.length2)
1907 text = ["@@ -", coords1, " +", coords2, " @@\n"]
1907 text = ["@@ -", coords1, " +", coords2, " @@\n"]
1908 # Escape the body of the patch with %xx notation.
1908 # Escape the body of the patch with %xx notation.
1909 for (op, data) in self.diffs:
1909 for (op, data) in self.diffs:
1910 if op == diff_match_patch.DIFF_INSERT:
1910 if op == diff_match_patch.DIFF_INSERT:
1911 text.append("+")
1911 text.append("+")
1912 elif op == diff_match_patch.DIFF_DELETE:
1912 elif op == diff_match_patch.DIFF_DELETE:
1913 text.append("-")
1913 text.append("-")
1914 elif op == diff_match_patch.DIFF_EQUAL:
1914 elif op == diff_match_patch.DIFF_EQUAL:
1915 text.append(" ")
1915 text.append(" ")
1916 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1916 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1917 data = data.encode("utf-8")
1917 data = data.encode("utf-8")
1918 text.append(urllib.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
1918 text.append(urllib.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
1919 return "".join(text) No newline at end of file
1919 return "".join(text)
@@ -1,1739 +1,1737 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2012-2018 RhodeCode GmbH
3 # Copyright (C) 2012-2018 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 pull request model for RhodeCode
23 pull request model for RhodeCode
24 """
24 """
25
25
26
26
27 import json
27 import json
28 import logging
28 import logging
29 import datetime
29 import datetime
30 import urllib
30 import urllib
31 import collections
31 import collections
32
32
33 from pyramid.threadlocal import get_current_request
33 from pyramid.threadlocal import get_current_request
34
34
35 from rhodecode import events
35 from rhodecode import events
36 from rhodecode.translation import lazy_ugettext#, _
36 from rhodecode.translation import lazy_ugettext#, _
37 from rhodecode.lib import helpers as h, hooks_utils, diffs
37 from rhodecode.lib import helpers as h, hooks_utils, diffs
38 from rhodecode.lib import audit_logger
38 from rhodecode.lib import audit_logger
39 from rhodecode.lib.compat import OrderedDict
39 from rhodecode.lib.compat import OrderedDict
40 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
40 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
41 from rhodecode.lib.markup_renderer import (
41 from rhodecode.lib.markup_renderer import (
42 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
42 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
43 from rhodecode.lib.utils2 import safe_unicode, safe_str, md5_safe
43 from rhodecode.lib.utils2 import safe_unicode, safe_str, md5_safe
44 from rhodecode.lib.vcs.backends.base import (
44 from rhodecode.lib.vcs.backends.base import (
45 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason)
45 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason)
46 from rhodecode.lib.vcs.conf import settings as vcs_settings
46 from rhodecode.lib.vcs.conf import settings as vcs_settings
47 from rhodecode.lib.vcs.exceptions import (
47 from rhodecode.lib.vcs.exceptions import (
48 CommitDoesNotExistError, EmptyRepositoryError)
48 CommitDoesNotExistError, EmptyRepositoryError)
49 from rhodecode.model import BaseModel
49 from rhodecode.model import BaseModel
50 from rhodecode.model.changeset_status import ChangesetStatusModel
50 from rhodecode.model.changeset_status import ChangesetStatusModel
51 from rhodecode.model.comment import CommentsModel
51 from rhodecode.model.comment import CommentsModel
52 from rhodecode.model.db import (
52 from rhodecode.model.db import (
53 or_, PullRequest, PullRequestReviewers, ChangesetStatus,
53 or_, PullRequest, PullRequestReviewers, ChangesetStatus,
54 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule)
54 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule)
55 from rhodecode.model.meta import Session
55 from rhodecode.model.meta import Session
56 from rhodecode.model.notification import NotificationModel, \
56 from rhodecode.model.notification import NotificationModel, \
57 EmailNotificationModel
57 EmailNotificationModel
58 from rhodecode.model.scm import ScmModel
58 from rhodecode.model.scm import ScmModel
59 from rhodecode.model.settings import VcsSettingsModel
59 from rhodecode.model.settings import VcsSettingsModel
60
60
61
61
62 log = logging.getLogger(__name__)
62 log = logging.getLogger(__name__)
63
63
64
64
65 # Data structure to hold the response data when updating commits during a pull
65 # Data structure to hold the response data when updating commits during a pull
66 # request update.
66 # request update.
67 UpdateResponse = collections.namedtuple('UpdateResponse', [
67 UpdateResponse = collections.namedtuple('UpdateResponse', [
68 'executed', 'reason', 'new', 'old', 'changes',
68 'executed', 'reason', 'new', 'old', 'changes',
69 'source_changed', 'target_changed'])
69 'source_changed', 'target_changed'])
70
70
71
71
72 class PullRequestModel(BaseModel):
72 class PullRequestModel(BaseModel):
73
73
74 cls = PullRequest
74 cls = PullRequest
75
75
76 DIFF_CONTEXT = diffs.DEFAULT_CONTEXT
76 DIFF_CONTEXT = diffs.DEFAULT_CONTEXT
77
77
78 MERGE_STATUS_MESSAGES = {
78 MERGE_STATUS_MESSAGES = {
79 MergeFailureReason.NONE: lazy_ugettext(
79 MergeFailureReason.NONE: lazy_ugettext(
80 'This pull request can be automatically merged.'),
80 'This pull request can be automatically merged.'),
81 MergeFailureReason.UNKNOWN: lazy_ugettext(
81 MergeFailureReason.UNKNOWN: lazy_ugettext(
82 'This pull request cannot be merged because of an unhandled'
82 'This pull request cannot be merged because of an unhandled'
83 ' exception.'),
83 ' exception.'),
84 MergeFailureReason.MERGE_FAILED: lazy_ugettext(
84 MergeFailureReason.MERGE_FAILED: lazy_ugettext(
85 'This pull request cannot be merged because of merge conflicts.'),
85 'This pull request cannot be merged because of merge conflicts.'),
86 MergeFailureReason.PUSH_FAILED: lazy_ugettext(
86 MergeFailureReason.PUSH_FAILED: lazy_ugettext(
87 'This pull request could not be merged because push to target'
87 'This pull request could not be merged because push to target'
88 ' failed.'),
88 ' failed.'),
89 MergeFailureReason.TARGET_IS_NOT_HEAD: lazy_ugettext(
89 MergeFailureReason.TARGET_IS_NOT_HEAD: lazy_ugettext(
90 'This pull request cannot be merged because the target is not a'
90 'This pull request cannot be merged because the target is not a'
91 ' head.'),
91 ' head.'),
92 MergeFailureReason.HG_SOURCE_HAS_MORE_BRANCHES: lazy_ugettext(
92 MergeFailureReason.HG_SOURCE_HAS_MORE_BRANCHES: lazy_ugettext(
93 'This pull request cannot be merged because the source contains'
93 'This pull request cannot be merged because the source contains'
94 ' more branches than the target.'),
94 ' more branches than the target.'),
95 MergeFailureReason.HG_TARGET_HAS_MULTIPLE_HEADS: lazy_ugettext(
95 MergeFailureReason.HG_TARGET_HAS_MULTIPLE_HEADS: lazy_ugettext(
96 'This pull request cannot be merged because the target has'
96 'This pull request cannot be merged because the target has'
97 ' multiple heads.'),
97 ' multiple heads.'),
98 MergeFailureReason.TARGET_IS_LOCKED: lazy_ugettext(
98 MergeFailureReason.TARGET_IS_LOCKED: lazy_ugettext(
99 'This pull request cannot be merged because the target repository'
99 'This pull request cannot be merged because the target repository'
100 ' is locked.'),
100 ' is locked.'),
101 MergeFailureReason._DEPRECATED_MISSING_COMMIT: lazy_ugettext(
101 MergeFailureReason._DEPRECATED_MISSING_COMMIT: lazy_ugettext(
102 'This pull request cannot be merged because the target or the '
102 'This pull request cannot be merged because the target or the '
103 'source reference is missing.'),
103 'source reference is missing.'),
104 MergeFailureReason.MISSING_TARGET_REF: lazy_ugettext(
104 MergeFailureReason.MISSING_TARGET_REF: lazy_ugettext(
105 'This pull request cannot be merged because the target '
105 'This pull request cannot be merged because the target '
106 'reference is missing.'),
106 'reference is missing.'),
107 MergeFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
107 MergeFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
108 'This pull request cannot be merged because the source '
108 'This pull request cannot be merged because the source '
109 'reference is missing.'),
109 'reference is missing.'),
110 MergeFailureReason.SUBREPO_MERGE_FAILED: lazy_ugettext(
110 MergeFailureReason.SUBREPO_MERGE_FAILED: lazy_ugettext(
111 'This pull request cannot be merged because of conflicts related '
111 'This pull request cannot be merged because of conflicts related '
112 'to sub repositories.'),
112 'to sub repositories.'),
113 }
113 }
114
114
115 UPDATE_STATUS_MESSAGES = {
115 UPDATE_STATUS_MESSAGES = {
116 UpdateFailureReason.NONE: lazy_ugettext(
116 UpdateFailureReason.NONE: lazy_ugettext(
117 'Pull request update successful.'),
117 'Pull request update successful.'),
118 UpdateFailureReason.UNKNOWN: lazy_ugettext(
118 UpdateFailureReason.UNKNOWN: lazy_ugettext(
119 'Pull request update failed because of an unknown error.'),
119 'Pull request update failed because of an unknown error.'),
120 UpdateFailureReason.NO_CHANGE: lazy_ugettext(
120 UpdateFailureReason.NO_CHANGE: lazy_ugettext(
121 'No update needed because the source and target have not changed.'),
121 'No update needed because the source and target have not changed.'),
122 UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
122 UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
123 'Pull request cannot be updated because the reference type is '
123 'Pull request cannot be updated because the reference type is '
124 'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
124 'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
125 UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
125 UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
126 'This pull request cannot be updated because the target '
126 'This pull request cannot be updated because the target '
127 'reference is missing.'),
127 'reference is missing.'),
128 UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
128 UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
129 'This pull request cannot be updated because the source '
129 'This pull request cannot be updated because the source '
130 'reference is missing.'),
130 'reference is missing.'),
131 }
131 }
132
132
133 def __get_pull_request(self, pull_request):
133 def __get_pull_request(self, pull_request):
134 return self._get_instance((
134 return self._get_instance((
135 PullRequest, PullRequestVersion), pull_request)
135 PullRequest, PullRequestVersion), pull_request)
136
136
137 def _check_perms(self, perms, pull_request, user, api=False):
137 def _check_perms(self, perms, pull_request, user, api=False):
138 if not api:
138 if not api:
139 return h.HasRepoPermissionAny(*perms)(
139 return h.HasRepoPermissionAny(*perms)(
140 user=user, repo_name=pull_request.target_repo.repo_name)
140 user=user, repo_name=pull_request.target_repo.repo_name)
141 else:
141 else:
142 return h.HasRepoPermissionAnyApi(*perms)(
142 return h.HasRepoPermissionAnyApi(*perms)(
143 user=user, repo_name=pull_request.target_repo.repo_name)
143 user=user, repo_name=pull_request.target_repo.repo_name)
144
144
145 def check_user_read(self, pull_request, user, api=False):
145 def check_user_read(self, pull_request, user, api=False):
146 _perms = ('repository.admin', 'repository.write', 'repository.read',)
146 _perms = ('repository.admin', 'repository.write', 'repository.read',)
147 return self._check_perms(_perms, pull_request, user, api)
147 return self._check_perms(_perms, pull_request, user, api)
148
148
149 def check_user_merge(self, pull_request, user, api=False):
149 def check_user_merge(self, pull_request, user, api=False):
150 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
150 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
151 return self._check_perms(_perms, pull_request, user, api)
151 return self._check_perms(_perms, pull_request, user, api)
152
152
153 def check_user_update(self, pull_request, user, api=False):
153 def check_user_update(self, pull_request, user, api=False):
154 owner = user.user_id == pull_request.user_id
154 owner = user.user_id == pull_request.user_id
155 return self.check_user_merge(pull_request, user, api) or owner
155 return self.check_user_merge(pull_request, user, api) or owner
156
156
157 def check_user_delete(self, pull_request, user):
157 def check_user_delete(self, pull_request, user):
158 owner = user.user_id == pull_request.user_id
158 owner = user.user_id == pull_request.user_id
159 _perms = ('repository.admin',)
159 _perms = ('repository.admin',)
160 return self._check_perms(_perms, pull_request, user) or owner
160 return self._check_perms(_perms, pull_request, user) or owner
161
161
162 def check_user_change_status(self, pull_request, user, api=False):
162 def check_user_change_status(self, pull_request, user, api=False):
163 reviewer = user.user_id in [x.user_id for x in
163 reviewer = user.user_id in [x.user_id for x in
164 pull_request.reviewers]
164 pull_request.reviewers]
165 return self.check_user_update(pull_request, user, api) or reviewer
165 return self.check_user_update(pull_request, user, api) or reviewer
166
166
167 def check_user_comment(self, pull_request, user):
167 def check_user_comment(self, pull_request, user):
168 owner = user.user_id == pull_request.user_id
168 owner = user.user_id == pull_request.user_id
169 return self.check_user_read(pull_request, user) or owner
169 return self.check_user_read(pull_request, user) or owner
170
170
171 def get(self, pull_request):
171 def get(self, pull_request):
172 return self.__get_pull_request(pull_request)
172 return self.__get_pull_request(pull_request)
173
173
174 def _prepare_get_all_query(self, repo_name, source=False, statuses=None,
174 def _prepare_get_all_query(self, repo_name, source=False, statuses=None,
175 opened_by=None, order_by=None,
175 opened_by=None, order_by=None,
176 order_dir='desc'):
176 order_dir='desc'):
177 repo = None
177 repo = None
178 if repo_name:
178 if repo_name:
179 repo = self._get_repo(repo_name)
179 repo = self._get_repo(repo_name)
180
180
181 q = PullRequest.query()
181 q = PullRequest.query()
182
182
183 # source or target
183 # source or target
184 if repo and source:
184 if repo and source:
185 q = q.filter(PullRequest.source_repo == repo)
185 q = q.filter(PullRequest.source_repo == repo)
186 elif repo:
186 elif repo:
187 q = q.filter(PullRequest.target_repo == repo)
187 q = q.filter(PullRequest.target_repo == repo)
188
188
189 # closed,opened
189 # closed,opened
190 if statuses:
190 if statuses:
191 q = q.filter(PullRequest.status.in_(statuses))
191 q = q.filter(PullRequest.status.in_(statuses))
192
192
193 # opened by filter
193 # opened by filter
194 if opened_by:
194 if opened_by:
195 q = q.filter(PullRequest.user_id.in_(opened_by))
195 q = q.filter(PullRequest.user_id.in_(opened_by))
196
196
197 if order_by:
197 if order_by:
198 order_map = {
198 order_map = {
199 'name_raw': PullRequest.pull_request_id,
199 'name_raw': PullRequest.pull_request_id,
200 'title': PullRequest.title,
200 'title': PullRequest.title,
201 'updated_on_raw': PullRequest.updated_on,
201 'updated_on_raw': PullRequest.updated_on,
202 'target_repo': PullRequest.target_repo_id
202 'target_repo': PullRequest.target_repo_id
203 }
203 }
204 if order_dir == 'asc':
204 if order_dir == 'asc':
205 q = q.order_by(order_map[order_by].asc())
205 q = q.order_by(order_map[order_by].asc())
206 else:
206 else:
207 q = q.order_by(order_map[order_by].desc())
207 q = q.order_by(order_map[order_by].desc())
208
208
209 return q
209 return q
210
210
211 def count_all(self, repo_name, source=False, statuses=None,
211 def count_all(self, repo_name, source=False, statuses=None,
212 opened_by=None):
212 opened_by=None):
213 """
213 """
214 Count the number of pull requests for a specific repository.
214 Count the number of pull requests for a specific repository.
215
215
216 :param repo_name: target or source repo
216 :param repo_name: target or source repo
217 :param source: boolean flag to specify if repo_name refers to source
217 :param source: boolean flag to specify if repo_name refers to source
218 :param statuses: list of pull request statuses
218 :param statuses: list of pull request statuses
219 :param opened_by: author user of the pull request
219 :param opened_by: author user of the pull request
220 :returns: int number of pull requests
220 :returns: int number of pull requests
221 """
221 """
222 q = self._prepare_get_all_query(
222 q = self._prepare_get_all_query(
223 repo_name, source=source, statuses=statuses, opened_by=opened_by)
223 repo_name, source=source, statuses=statuses, opened_by=opened_by)
224
224
225 return q.count()
225 return q.count()
226
226
227 def get_all(self, repo_name, source=False, statuses=None, opened_by=None,
227 def get_all(self, repo_name, source=False, statuses=None, opened_by=None,
228 offset=0, length=None, order_by=None, order_dir='desc'):
228 offset=0, length=None, order_by=None, order_dir='desc'):
229 """
229 """
230 Get all pull requests for a specific repository.
230 Get all pull requests for a specific repository.
231
231
232 :param repo_name: target or source repo
232 :param repo_name: target or source repo
233 :param source: boolean flag to specify if repo_name refers to source
233 :param source: boolean flag to specify if repo_name refers to source
234 :param statuses: list of pull request statuses
234 :param statuses: list of pull request statuses
235 :param opened_by: author user of the pull request
235 :param opened_by: author user of the pull request
236 :param offset: pagination offset
236 :param offset: pagination offset
237 :param length: length of returned list
237 :param length: length of returned list
238 :param order_by: order of the returned list
238 :param order_by: order of the returned list
239 :param order_dir: 'asc' or 'desc' ordering direction
239 :param order_dir: 'asc' or 'desc' ordering direction
240 :returns: list of pull requests
240 :returns: list of pull requests
241 """
241 """
242 q = self._prepare_get_all_query(
242 q = self._prepare_get_all_query(
243 repo_name, source=source, statuses=statuses, opened_by=opened_by,
243 repo_name, source=source, statuses=statuses, opened_by=opened_by,
244 order_by=order_by, order_dir=order_dir)
244 order_by=order_by, order_dir=order_dir)
245
245
246 if length:
246 if length:
247 pull_requests = q.limit(length).offset(offset).all()
247 pull_requests = q.limit(length).offset(offset).all()
248 else:
248 else:
249 pull_requests = q.all()
249 pull_requests = q.all()
250
250
251 return pull_requests
251 return pull_requests
252
252
253 def count_awaiting_review(self, repo_name, source=False, statuses=None,
253 def count_awaiting_review(self, repo_name, source=False, statuses=None,
254 opened_by=None):
254 opened_by=None):
255 """
255 """
256 Count the number of pull requests for a specific repository that are
256 Count the number of pull requests for a specific repository that are
257 awaiting review.
257 awaiting review.
258
258
259 :param repo_name: target or source repo
259 :param repo_name: target or source repo
260 :param source: boolean flag to specify if repo_name refers to source
260 :param source: boolean flag to specify if repo_name refers to source
261 :param statuses: list of pull request statuses
261 :param statuses: list of pull request statuses
262 :param opened_by: author user of the pull request
262 :param opened_by: author user of the pull request
263 :returns: int number of pull requests
263 :returns: int number of pull requests
264 """
264 """
265 pull_requests = self.get_awaiting_review(
265 pull_requests = self.get_awaiting_review(
266 repo_name, source=source, statuses=statuses, opened_by=opened_by)
266 repo_name, source=source, statuses=statuses, opened_by=opened_by)
267
267
268 return len(pull_requests)
268 return len(pull_requests)
269
269
270 def get_awaiting_review(self, repo_name, source=False, statuses=None,
270 def get_awaiting_review(self, repo_name, source=False, statuses=None,
271 opened_by=None, offset=0, length=None,
271 opened_by=None, offset=0, length=None,
272 order_by=None, order_dir='desc'):
272 order_by=None, order_dir='desc'):
273 """
273 """
274 Get all pull requests for a specific repository that are awaiting
274 Get all pull requests for a specific repository that are awaiting
275 review.
275 review.
276
276
277 :param repo_name: target or source repo
277 :param repo_name: target or source repo
278 :param source: boolean flag to specify if repo_name refers to source
278 :param source: boolean flag to specify if repo_name refers to source
279 :param statuses: list of pull request statuses
279 :param statuses: list of pull request statuses
280 :param opened_by: author user of the pull request
280 :param opened_by: author user of the pull request
281 :param offset: pagination offset
281 :param offset: pagination offset
282 :param length: length of returned list
282 :param length: length of returned list
283 :param order_by: order of the returned list
283 :param order_by: order of the returned list
284 :param order_dir: 'asc' or 'desc' ordering direction
284 :param order_dir: 'asc' or 'desc' ordering direction
285 :returns: list of pull requests
285 :returns: list of pull requests
286 """
286 """
287 pull_requests = self.get_all(
287 pull_requests = self.get_all(
288 repo_name, source=source, statuses=statuses, opened_by=opened_by,
288 repo_name, source=source, statuses=statuses, opened_by=opened_by,
289 order_by=order_by, order_dir=order_dir)
289 order_by=order_by, order_dir=order_dir)
290
290
291 _filtered_pull_requests = []
291 _filtered_pull_requests = []
292 for pr in pull_requests:
292 for pr in pull_requests:
293 status = pr.calculated_review_status()
293 status = pr.calculated_review_status()
294 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
294 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
295 ChangesetStatus.STATUS_UNDER_REVIEW]:
295 ChangesetStatus.STATUS_UNDER_REVIEW]:
296 _filtered_pull_requests.append(pr)
296 _filtered_pull_requests.append(pr)
297 if length:
297 if length:
298 return _filtered_pull_requests[offset:offset+length]
298 return _filtered_pull_requests[offset:offset+length]
299 else:
299 else:
300 return _filtered_pull_requests
300 return _filtered_pull_requests
301
301
302 def count_awaiting_my_review(self, repo_name, source=False, statuses=None,
302 def count_awaiting_my_review(self, repo_name, source=False, statuses=None,
303 opened_by=None, user_id=None):
303 opened_by=None, user_id=None):
304 """
304 """
305 Count the number of pull requests for a specific repository that are
305 Count the number of pull requests for a specific repository that are
306 awaiting review from a specific user.
306 awaiting review from a specific user.
307
307
308 :param repo_name: target or source repo
308 :param repo_name: target or source repo
309 :param source: boolean flag to specify if repo_name refers to source
309 :param source: boolean flag to specify if repo_name refers to source
310 :param statuses: list of pull request statuses
310 :param statuses: list of pull request statuses
311 :param opened_by: author user of the pull request
311 :param opened_by: author user of the pull request
312 :param user_id: reviewer user of the pull request
312 :param user_id: reviewer user of the pull request
313 :returns: int number of pull requests
313 :returns: int number of pull requests
314 """
314 """
315 pull_requests = self.get_awaiting_my_review(
315 pull_requests = self.get_awaiting_my_review(
316 repo_name, source=source, statuses=statuses, opened_by=opened_by,
316 repo_name, source=source, statuses=statuses, opened_by=opened_by,
317 user_id=user_id)
317 user_id=user_id)
318
318
319 return len(pull_requests)
319 return len(pull_requests)
320
320
321 def get_awaiting_my_review(self, repo_name, source=False, statuses=None,
321 def get_awaiting_my_review(self, repo_name, source=False, statuses=None,
322 opened_by=None, user_id=None, offset=0,
322 opened_by=None, user_id=None, offset=0,
323 length=None, order_by=None, order_dir='desc'):
323 length=None, order_by=None, order_dir='desc'):
324 """
324 """
325 Get all pull requests for a specific repository that are awaiting
325 Get all pull requests for a specific repository that are awaiting
326 review from a specific user.
326 review from a specific user.
327
327
328 :param repo_name: target or source repo
328 :param repo_name: target or source repo
329 :param source: boolean flag to specify if repo_name refers to source
329 :param source: boolean flag to specify if repo_name refers to source
330 :param statuses: list of pull request statuses
330 :param statuses: list of pull request statuses
331 :param opened_by: author user of the pull request
331 :param opened_by: author user of the pull request
332 :param user_id: reviewer user of the pull request
332 :param user_id: reviewer user of the pull request
333 :param offset: pagination offset
333 :param offset: pagination offset
334 :param length: length of returned list
334 :param length: length of returned list
335 :param order_by: order of the returned list
335 :param order_by: order of the returned list
336 :param order_dir: 'asc' or 'desc' ordering direction
336 :param order_dir: 'asc' or 'desc' ordering direction
337 :returns: list of pull requests
337 :returns: list of pull requests
338 """
338 """
339 pull_requests = self.get_all(
339 pull_requests = self.get_all(
340 repo_name, source=source, statuses=statuses, opened_by=opened_by,
340 repo_name, source=source, statuses=statuses, opened_by=opened_by,
341 order_by=order_by, order_dir=order_dir)
341 order_by=order_by, order_dir=order_dir)
342
342
343 _my = PullRequestModel().get_not_reviewed(user_id)
343 _my = PullRequestModel().get_not_reviewed(user_id)
344 my_participation = []
344 my_participation = []
345 for pr in pull_requests:
345 for pr in pull_requests:
346 if pr in _my:
346 if pr in _my:
347 my_participation.append(pr)
347 my_participation.append(pr)
348 _filtered_pull_requests = my_participation
348 _filtered_pull_requests = my_participation
349 if length:
349 if length:
350 return _filtered_pull_requests[offset:offset+length]
350 return _filtered_pull_requests[offset:offset+length]
351 else:
351 else:
352 return _filtered_pull_requests
352 return _filtered_pull_requests
353
353
354 def get_not_reviewed(self, user_id):
354 def get_not_reviewed(self, user_id):
355 return [
355 return [
356 x.pull_request for x in PullRequestReviewers.query().filter(
356 x.pull_request for x in PullRequestReviewers.query().filter(
357 PullRequestReviewers.user_id == user_id).all()
357 PullRequestReviewers.user_id == user_id).all()
358 ]
358 ]
359
359
360 def _prepare_participating_query(self, user_id=None, statuses=None,
360 def _prepare_participating_query(self, user_id=None, statuses=None,
361 order_by=None, order_dir='desc'):
361 order_by=None, order_dir='desc'):
362 q = PullRequest.query()
362 q = PullRequest.query()
363 if user_id:
363 if user_id:
364 reviewers_subquery = Session().query(
364 reviewers_subquery = Session().query(
365 PullRequestReviewers.pull_request_id).filter(
365 PullRequestReviewers.pull_request_id).filter(
366 PullRequestReviewers.user_id == user_id).subquery()
366 PullRequestReviewers.user_id == user_id).subquery()
367 user_filter = or_(
367 user_filter = or_(
368 PullRequest.user_id == user_id,
368 PullRequest.user_id == user_id,
369 PullRequest.pull_request_id.in_(reviewers_subquery)
369 PullRequest.pull_request_id.in_(reviewers_subquery)
370 )
370 )
371 q = PullRequest.query().filter(user_filter)
371 q = PullRequest.query().filter(user_filter)
372
372
373 # closed,opened
373 # closed,opened
374 if statuses:
374 if statuses:
375 q = q.filter(PullRequest.status.in_(statuses))
375 q = q.filter(PullRequest.status.in_(statuses))
376
376
377 if order_by:
377 if order_by:
378 order_map = {
378 order_map = {
379 'name_raw': PullRequest.pull_request_id,
379 'name_raw': PullRequest.pull_request_id,
380 'title': PullRequest.title,
380 'title': PullRequest.title,
381 'updated_on_raw': PullRequest.updated_on,
381 'updated_on_raw': PullRequest.updated_on,
382 'target_repo': PullRequest.target_repo_id
382 'target_repo': PullRequest.target_repo_id
383 }
383 }
384 if order_dir == 'asc':
384 if order_dir == 'asc':
385 q = q.order_by(order_map[order_by].asc())
385 q = q.order_by(order_map[order_by].asc())
386 else:
386 else:
387 q = q.order_by(order_map[order_by].desc())
387 q = q.order_by(order_map[order_by].desc())
388
388
389 return q
389 return q
390
390
391 def count_im_participating_in(self, user_id=None, statuses=None):
391 def count_im_participating_in(self, user_id=None, statuses=None):
392 q = self._prepare_participating_query(user_id, statuses=statuses)
392 q = self._prepare_participating_query(user_id, statuses=statuses)
393 return q.count()
393 return q.count()
394
394
395 def get_im_participating_in(
395 def get_im_participating_in(
396 self, user_id=None, statuses=None, offset=0,
396 self, user_id=None, statuses=None, offset=0,
397 length=None, order_by=None, order_dir='desc'):
397 length=None, order_by=None, order_dir='desc'):
398 """
398 """
399 Get all Pull requests that i'm participating in, or i have opened
399 Get all Pull requests that i'm participating in, or i have opened
400 """
400 """
401
401
402 q = self._prepare_participating_query(
402 q = self._prepare_participating_query(
403 user_id, statuses=statuses, order_by=order_by,
403 user_id, statuses=statuses, order_by=order_by,
404 order_dir=order_dir)
404 order_dir=order_dir)
405
405
406 if length:
406 if length:
407 pull_requests = q.limit(length).offset(offset).all()
407 pull_requests = q.limit(length).offset(offset).all()
408 else:
408 else:
409 pull_requests = q.all()
409 pull_requests = q.all()
410
410
411 return pull_requests
411 return pull_requests
412
412
413 def get_versions(self, pull_request):
413 def get_versions(self, pull_request):
414 """
414 """
415 returns version of pull request sorted by ID descending
415 returns version of pull request sorted by ID descending
416 """
416 """
417 return PullRequestVersion.query()\
417 return PullRequestVersion.query()\
418 .filter(PullRequestVersion.pull_request == pull_request)\
418 .filter(PullRequestVersion.pull_request == pull_request)\
419 .order_by(PullRequestVersion.pull_request_version_id.asc())\
419 .order_by(PullRequestVersion.pull_request_version_id.asc())\
420 .all()
420 .all()
421
421
422 def get_pr_version(self, pull_request_id, version=None):
422 def get_pr_version(self, pull_request_id, version=None):
423 at_version = None
423 at_version = None
424
424
425 if version and version == 'latest':
425 if version and version == 'latest':
426 pull_request_ver = PullRequest.get(pull_request_id)
426 pull_request_ver = PullRequest.get(pull_request_id)
427 pull_request_obj = pull_request_ver
427 pull_request_obj = pull_request_ver
428 _org_pull_request_obj = pull_request_obj
428 _org_pull_request_obj = pull_request_obj
429 at_version = 'latest'
429 at_version = 'latest'
430 elif version:
430 elif version:
431 pull_request_ver = PullRequestVersion.get_or_404(version)
431 pull_request_ver = PullRequestVersion.get_or_404(version)
432 pull_request_obj = pull_request_ver
432 pull_request_obj = pull_request_ver
433 _org_pull_request_obj = pull_request_ver.pull_request
433 _org_pull_request_obj = pull_request_ver.pull_request
434 at_version = pull_request_ver.pull_request_version_id
434 at_version = pull_request_ver.pull_request_version_id
435 else:
435 else:
436 _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
436 _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
437 pull_request_id)
437 pull_request_id)
438
438
439 pull_request_display_obj = PullRequest.get_pr_display_object(
439 pull_request_display_obj = PullRequest.get_pr_display_object(
440 pull_request_obj, _org_pull_request_obj)
440 pull_request_obj, _org_pull_request_obj)
441
441
442 return _org_pull_request_obj, pull_request_obj, \
442 return _org_pull_request_obj, pull_request_obj, \
443 pull_request_display_obj, at_version
443 pull_request_display_obj, at_version
444
444
445 def create(self, created_by, source_repo, source_ref, target_repo,
445 def create(self, created_by, source_repo, source_ref, target_repo,
446 target_ref, revisions, reviewers, title, description=None,
446 target_ref, revisions, reviewers, title, description=None,
447 description_renderer=None,
447 description_renderer=None,
448 reviewer_data=None, translator=None, auth_user=None):
448 reviewer_data=None, translator=None, auth_user=None):
449 translator = translator or get_current_request().translate
449 translator = translator or get_current_request().translate
450
450
451 created_by_user = self._get_user(created_by)
451 created_by_user = self._get_user(created_by)
452 auth_user = auth_user or created_by_user.AuthUser()
452 auth_user = auth_user or created_by_user.AuthUser()
453 source_repo = self._get_repo(source_repo)
453 source_repo = self._get_repo(source_repo)
454 target_repo = self._get_repo(target_repo)
454 target_repo = self._get_repo(target_repo)
455
455
456 pull_request = PullRequest()
456 pull_request = PullRequest()
457 pull_request.source_repo = source_repo
457 pull_request.source_repo = source_repo
458 pull_request.source_ref = source_ref
458 pull_request.source_ref = source_ref
459 pull_request.target_repo = target_repo
459 pull_request.target_repo = target_repo
460 pull_request.target_ref = target_ref
460 pull_request.target_ref = target_ref
461 pull_request.revisions = revisions
461 pull_request.revisions = revisions
462 pull_request.title = title
462 pull_request.title = title
463 pull_request.description = description
463 pull_request.description = description
464 pull_request.description_renderer = description_renderer
464 pull_request.description_renderer = description_renderer
465 pull_request.author = created_by_user
465 pull_request.author = created_by_user
466 pull_request.reviewer_data = reviewer_data
466 pull_request.reviewer_data = reviewer_data
467
467
468 Session().add(pull_request)
468 Session().add(pull_request)
469 Session().flush()
469 Session().flush()
470
470
471 reviewer_ids = set()
471 reviewer_ids = set()
472 # members / reviewers
472 # members / reviewers
473 for reviewer_object in reviewers:
473 for reviewer_object in reviewers:
474 user_id, reasons, mandatory, rules = reviewer_object
474 user_id, reasons, mandatory, rules = reviewer_object
475 user = self._get_user(user_id)
475 user = self._get_user(user_id)
476
476
477 # skip duplicates
477 # skip duplicates
478 if user.user_id in reviewer_ids:
478 if user.user_id in reviewer_ids:
479 continue
479 continue
480
480
481 reviewer_ids.add(user.user_id)
481 reviewer_ids.add(user.user_id)
482
482
483 reviewer = PullRequestReviewers()
483 reviewer = PullRequestReviewers()
484 reviewer.user = user
484 reviewer.user = user
485 reviewer.pull_request = pull_request
485 reviewer.pull_request = pull_request
486 reviewer.reasons = reasons
486 reviewer.reasons = reasons
487 reviewer.mandatory = mandatory
487 reviewer.mandatory = mandatory
488
488
489 # NOTE(marcink): pick only first rule for now
489 # NOTE(marcink): pick only first rule for now
490 rule_id = list(rules)[0] if rules else None
490 rule_id = list(rules)[0] if rules else None
491 rule = RepoReviewRule.get(rule_id) if rule_id else None
491 rule = RepoReviewRule.get(rule_id) if rule_id else None
492 if rule:
492 if rule:
493 review_group = rule.user_group_vote_rule(user_id)
493 review_group = rule.user_group_vote_rule(user_id)
494 # we check if this particular reviewer is member of a voting group
494 # we check if this particular reviewer is member of a voting group
495 if review_group:
495 if review_group:
496 # NOTE(marcink):
496 # NOTE(marcink):
497 # can be that user is member of more but we pick the first same,
497 # can be that user is member of more but we pick the first same,
498 # same as default reviewers algo
498 # same as default reviewers algo
499 review_group = review_group[0]
499 review_group = review_group[0]
500
500
501 rule_data = {
501 rule_data = {
502 'rule_name':
502 'rule_name':
503 rule.review_rule_name,
503 rule.review_rule_name,
504 'rule_user_group_entry_id':
504 'rule_user_group_entry_id':
505 review_group.repo_review_rule_users_group_id,
505 review_group.repo_review_rule_users_group_id,
506 'rule_user_group_name':
506 'rule_user_group_name':
507 review_group.users_group.users_group_name,
507 review_group.users_group.users_group_name,
508 'rule_user_group_members':
508 'rule_user_group_members':
509 [x.user.username for x in review_group.users_group.members],
509 [x.user.username for x in review_group.users_group.members],
510 'rule_user_group_members_id':
510 'rule_user_group_members_id':
511 [x.user.user_id for x in review_group.users_group.members],
511 [x.user.user_id for x in review_group.users_group.members],
512 }
512 }
513 # e.g {'vote_rule': -1, 'mandatory': True}
513 # e.g {'vote_rule': -1, 'mandatory': True}
514 rule_data.update(review_group.rule_data())
514 rule_data.update(review_group.rule_data())
515
515
516 reviewer.rule_data = rule_data
516 reviewer.rule_data = rule_data
517
517
518 Session().add(reviewer)
518 Session().add(reviewer)
519 Session().flush()
519 Session().flush()
520
520
521 # Set approval status to "Under Review" for all commits which are
521 # Set approval status to "Under Review" for all commits which are
522 # part of this pull request.
522 # part of this pull request.
523 ChangesetStatusModel().set_status(
523 ChangesetStatusModel().set_status(
524 repo=target_repo,
524 repo=target_repo,
525 status=ChangesetStatus.STATUS_UNDER_REVIEW,
525 status=ChangesetStatus.STATUS_UNDER_REVIEW,
526 user=created_by_user,
526 user=created_by_user,
527 pull_request=pull_request
527 pull_request=pull_request
528 )
528 )
529 # we commit early at this point. This has to do with a fact
529 # we commit early at this point. This has to do with a fact
530 # that before queries do some row-locking. And because of that
530 # that before queries do some row-locking. And because of that
531 # we need to commit and finish transation before below validate call
531 # we need to commit and finish transation before below validate call
532 # that for large repos could be long resulting in long row locks
532 # that for large repos could be long resulting in long row locks
533 Session().commit()
533 Session().commit()
534
534
535 # prepare workspace, and run initial merge simulation
535 # prepare workspace, and run initial merge simulation
536 MergeCheck.validate(
536 MergeCheck.validate(
537 pull_request, auth_user=auth_user, translator=translator)
537 pull_request, auth_user=auth_user, translator=translator)
538
538
539 self.notify_reviewers(pull_request, reviewer_ids)
539 self.notify_reviewers(pull_request, reviewer_ids)
540 self._trigger_pull_request_hook(
540 self._trigger_pull_request_hook(
541 pull_request, created_by_user, 'create')
541 pull_request, created_by_user, 'create')
542
542
543 creation_data = pull_request.get_api_data(with_merge_state=False)
543 creation_data = pull_request.get_api_data(with_merge_state=False)
544 self._log_audit_action(
544 self._log_audit_action(
545 'repo.pull_request.create', {'data': creation_data},
545 'repo.pull_request.create', {'data': creation_data},
546 auth_user, pull_request)
546 auth_user, pull_request)
547
547
548 return pull_request
548 return pull_request
549
549
550 def _trigger_pull_request_hook(self, pull_request, user, action):
550 def _trigger_pull_request_hook(self, pull_request, user, action):
551 pull_request = self.__get_pull_request(pull_request)
551 pull_request = self.__get_pull_request(pull_request)
552 target_scm = pull_request.target_repo.scm_instance()
552 target_scm = pull_request.target_repo.scm_instance()
553 if action == 'create':
553 if action == 'create':
554 trigger_hook = hooks_utils.trigger_log_create_pull_request_hook
554 trigger_hook = hooks_utils.trigger_log_create_pull_request_hook
555 elif action == 'merge':
555 elif action == 'merge':
556 trigger_hook = hooks_utils.trigger_log_merge_pull_request_hook
556 trigger_hook = hooks_utils.trigger_log_merge_pull_request_hook
557 elif action == 'close':
557 elif action == 'close':
558 trigger_hook = hooks_utils.trigger_log_close_pull_request_hook
558 trigger_hook = hooks_utils.trigger_log_close_pull_request_hook
559 elif action == 'review_status_change':
559 elif action == 'review_status_change':
560 trigger_hook = hooks_utils.trigger_log_review_pull_request_hook
560 trigger_hook = hooks_utils.trigger_log_review_pull_request_hook
561 elif action == 'update':
561 elif action == 'update':
562 trigger_hook = hooks_utils.trigger_log_update_pull_request_hook
562 trigger_hook = hooks_utils.trigger_log_update_pull_request_hook
563 else:
563 else:
564 return
564 return
565
565
566 trigger_hook(
566 trigger_hook(
567 username=user.username,
567 username=user.username,
568 repo_name=pull_request.target_repo.repo_name,
568 repo_name=pull_request.target_repo.repo_name,
569 repo_alias=target_scm.alias,
569 repo_alias=target_scm.alias,
570 pull_request=pull_request)
570 pull_request=pull_request)
571
571
572 def _get_commit_ids(self, pull_request):
572 def _get_commit_ids(self, pull_request):
573 """
573 """
574 Return the commit ids of the merged pull request.
574 Return the commit ids of the merged pull request.
575
575
576 This method is not dealing correctly yet with the lack of autoupdates
576 This method is not dealing correctly yet with the lack of autoupdates
577 nor with the implicit target updates.
577 nor with the implicit target updates.
578 For example: if a commit in the source repo is already in the target it
578 For example: if a commit in the source repo is already in the target it
579 will be reported anyways.
579 will be reported anyways.
580 """
580 """
581 merge_rev = pull_request.merge_rev
581 merge_rev = pull_request.merge_rev
582 if merge_rev is None:
582 if merge_rev is None:
583 raise ValueError('This pull request was not merged yet')
583 raise ValueError('This pull request was not merged yet')
584
584
585 commit_ids = list(pull_request.revisions)
585 commit_ids = list(pull_request.revisions)
586 if merge_rev not in commit_ids:
586 if merge_rev not in commit_ids:
587 commit_ids.append(merge_rev)
587 commit_ids.append(merge_rev)
588
588
589 return commit_ids
589 return commit_ids
590
590
591 def merge_repo(self, pull_request, user, extras):
591 def merge_repo(self, pull_request, user, extras):
592 log.debug("Merging pull request %s", pull_request.pull_request_id)
592 log.debug("Merging pull request %s", pull_request.pull_request_id)
593 extras['user_agent'] = 'internal-merge'
593 extras['user_agent'] = 'internal-merge'
594 merge_state = self._merge_pull_request(pull_request, user, extras)
594 merge_state = self._merge_pull_request(pull_request, user, extras)
595 if merge_state.executed:
595 if merge_state.executed:
596 log.debug(
596 log.debug(
597 "Merge was successful, updating the pull request comments.")
597 "Merge was successful, updating the pull request comments.")
598 self._comment_and_close_pr(pull_request, user, merge_state)
598 self._comment_and_close_pr(pull_request, user, merge_state)
599
599
600 self._log_audit_action(
600 self._log_audit_action(
601 'repo.pull_request.merge',
601 'repo.pull_request.merge',
602 {'merge_state': merge_state.__dict__},
602 {'merge_state': merge_state.__dict__},
603 user, pull_request)
603 user, pull_request)
604
604
605 else:
605 else:
606 log.warn("Merge failed, not updating the pull request.")
606 log.warn("Merge failed, not updating the pull request.")
607 return merge_state
607 return merge_state
608
608
609 def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
609 def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
610 target_vcs = pull_request.target_repo.scm_instance()
610 target_vcs = pull_request.target_repo.scm_instance()
611 source_vcs = pull_request.source_repo.scm_instance()
611 source_vcs = pull_request.source_repo.scm_instance()
612
612
613 message = safe_unicode(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
613 message = safe_unicode(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
614 pr_id=pull_request.pull_request_id,
614 pr_id=pull_request.pull_request_id,
615 pr_title=pull_request.title,
615 pr_title=pull_request.title,
616 source_repo=source_vcs.name,
616 source_repo=source_vcs.name,
617 source_ref_name=pull_request.source_ref_parts.name,
617 source_ref_name=pull_request.source_ref_parts.name,
618 target_repo=target_vcs.name,
618 target_repo=target_vcs.name,
619 target_ref_name=pull_request.target_ref_parts.name,
619 target_ref_name=pull_request.target_ref_parts.name,
620 )
620 )
621
621
622 workspace_id = self._workspace_id(pull_request)
622 workspace_id = self._workspace_id(pull_request)
623 repo_id = pull_request.target_repo.repo_id
623 repo_id = pull_request.target_repo.repo_id
624 use_rebase = self._use_rebase_for_merging(pull_request)
624 use_rebase = self._use_rebase_for_merging(pull_request)
625 close_branch = self._close_branch_before_merging(pull_request)
625 close_branch = self._close_branch_before_merging(pull_request)
626
626
627 target_ref = self._refresh_reference(
627 target_ref = self._refresh_reference(
628 pull_request.target_ref_parts, target_vcs)
628 pull_request.target_ref_parts, target_vcs)
629
629
630 callback_daemon, extras = prepare_callback_daemon(
630 callback_daemon, extras = prepare_callback_daemon(
631 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
631 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
632 host=vcs_settings.HOOKS_HOST,
632 host=vcs_settings.HOOKS_HOST,
633 use_direct_calls=vcs_settings.HOOKS_DIRECT_CALLS)
633 use_direct_calls=vcs_settings.HOOKS_DIRECT_CALLS)
634
634
635 with callback_daemon:
635 with callback_daemon:
636 # TODO: johbo: Implement a clean way to run a config_override
636 # TODO: johbo: Implement a clean way to run a config_override
637 # for a single call.
637 # for a single call.
638 target_vcs.config.set(
638 target_vcs.config.set(
639 'rhodecode', 'RC_SCM_DATA', json.dumps(extras))
639 'rhodecode', 'RC_SCM_DATA', json.dumps(extras))
640
640
641 user_name = user.short_contact
641 user_name = user.short_contact
642 merge_state = target_vcs.merge(
642 merge_state = target_vcs.merge(
643 repo_id, workspace_id, target_ref, source_vcs,
643 repo_id, workspace_id, target_ref, source_vcs,
644 pull_request.source_ref_parts,
644 pull_request.source_ref_parts,
645 user_name=user_name, user_email=user.email,
645 user_name=user_name, user_email=user.email,
646 message=message, use_rebase=use_rebase,
646 message=message, use_rebase=use_rebase,
647 close_branch=close_branch)
647 close_branch=close_branch)
648 return merge_state
648 return merge_state
649
649
650 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
650 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
651 pull_request.merge_rev = merge_state.merge_ref.commit_id
651 pull_request.merge_rev = merge_state.merge_ref.commit_id
652 pull_request.updated_on = datetime.datetime.now()
652 pull_request.updated_on = datetime.datetime.now()
653 close_msg = close_msg or 'Pull request merged and closed'
653 close_msg = close_msg or 'Pull request merged and closed'
654
654
655 CommentsModel().create(
655 CommentsModel().create(
656 text=safe_unicode(close_msg),
656 text=safe_unicode(close_msg),
657 repo=pull_request.target_repo.repo_id,
657 repo=pull_request.target_repo.repo_id,
658 user=user.user_id,
658 user=user.user_id,
659 pull_request=pull_request.pull_request_id,
659 pull_request=pull_request.pull_request_id,
660 f_path=None,
660 f_path=None,
661 line_no=None,
661 line_no=None,
662 closing_pr=True
662 closing_pr=True
663 )
663 )
664
664
665 Session().add(pull_request)
665 Session().add(pull_request)
666 Session().flush()
666 Session().flush()
667 # TODO: paris: replace invalidation with less radical solution
667 # TODO: paris: replace invalidation with less radical solution
668 ScmModel().mark_for_invalidation(
668 ScmModel().mark_for_invalidation(
669 pull_request.target_repo.repo_name)
669 pull_request.target_repo.repo_name)
670 self._trigger_pull_request_hook(pull_request, user, 'merge')
670 self._trigger_pull_request_hook(pull_request, user, 'merge')
671
671
672 def has_valid_update_type(self, pull_request):
672 def has_valid_update_type(self, pull_request):
673 source_ref_type = pull_request.source_ref_parts.type
673 source_ref_type = pull_request.source_ref_parts.type
674 return source_ref_type in ['book', 'branch', 'tag']
674 return source_ref_type in ['book', 'branch', 'tag']
675
675
676 def update_commits(self, pull_request):
676 def update_commits(self, pull_request):
677 """
677 """
678 Get the updated list of commits for the pull request
678 Get the updated list of commits for the pull request
679 and return the new pull request version and the list
679 and return the new pull request version and the list
680 of commits processed by this update action
680 of commits processed by this update action
681 """
681 """
682 pull_request = self.__get_pull_request(pull_request)
682 pull_request = self.__get_pull_request(pull_request)
683 source_ref_type = pull_request.source_ref_parts.type
683 source_ref_type = pull_request.source_ref_parts.type
684 source_ref_name = pull_request.source_ref_parts.name
684 source_ref_name = pull_request.source_ref_parts.name
685 source_ref_id = pull_request.source_ref_parts.commit_id
685 source_ref_id = pull_request.source_ref_parts.commit_id
686
686
687 target_ref_type = pull_request.target_ref_parts.type
687 target_ref_type = pull_request.target_ref_parts.type
688 target_ref_name = pull_request.target_ref_parts.name
688 target_ref_name = pull_request.target_ref_parts.name
689 target_ref_id = pull_request.target_ref_parts.commit_id
689 target_ref_id = pull_request.target_ref_parts.commit_id
690
690
691 if not self.has_valid_update_type(pull_request):
691 if not self.has_valid_update_type(pull_request):
692 log.debug(
692 log.debug(
693 "Skipping update of pull request %s due to ref type: %s",
693 "Skipping update of pull request %s due to ref type: %s",
694 pull_request, source_ref_type)
694 pull_request, source_ref_type)
695 return UpdateResponse(
695 return UpdateResponse(
696 executed=False,
696 executed=False,
697 reason=UpdateFailureReason.WRONG_REF_TYPE,
697 reason=UpdateFailureReason.WRONG_REF_TYPE,
698 old=pull_request, new=None, changes=None,
698 old=pull_request, new=None, changes=None,
699 source_changed=False, target_changed=False)
699 source_changed=False, target_changed=False)
700
700
701 # source repo
701 # source repo
702 source_repo = pull_request.source_repo.scm_instance()
702 source_repo = pull_request.source_repo.scm_instance()
703 try:
703 try:
704 source_commit = source_repo.get_commit(commit_id=source_ref_name)
704 source_commit = source_repo.get_commit(commit_id=source_ref_name)
705 except CommitDoesNotExistError:
705 except CommitDoesNotExistError:
706 return UpdateResponse(
706 return UpdateResponse(
707 executed=False,
707 executed=False,
708 reason=UpdateFailureReason.MISSING_SOURCE_REF,
708 reason=UpdateFailureReason.MISSING_SOURCE_REF,
709 old=pull_request, new=None, changes=None,
709 old=pull_request, new=None, changes=None,
710 source_changed=False, target_changed=False)
710 source_changed=False, target_changed=False)
711
711
712 source_changed = source_ref_id != source_commit.raw_id
712 source_changed = source_ref_id != source_commit.raw_id
713
713
714 # target repo
714 # target repo
715 target_repo = pull_request.target_repo.scm_instance()
715 target_repo = pull_request.target_repo.scm_instance()
716 try:
716 try:
717 target_commit = target_repo.get_commit(commit_id=target_ref_name)
717 target_commit = target_repo.get_commit(commit_id=target_ref_name)
718 except CommitDoesNotExistError:
718 except CommitDoesNotExistError:
719 return UpdateResponse(
719 return UpdateResponse(
720 executed=False,
720 executed=False,
721 reason=UpdateFailureReason.MISSING_TARGET_REF,
721 reason=UpdateFailureReason.MISSING_TARGET_REF,
722 old=pull_request, new=None, changes=None,
722 old=pull_request, new=None, changes=None,
723 source_changed=False, target_changed=False)
723 source_changed=False, target_changed=False)
724 target_changed = target_ref_id != target_commit.raw_id
724 target_changed = target_ref_id != target_commit.raw_id
725
725
726 if not (source_changed or target_changed):
726 if not (source_changed or target_changed):
727 log.debug("Nothing changed in pull request %s", pull_request)
727 log.debug("Nothing changed in pull request %s", pull_request)
728 return UpdateResponse(
728 return UpdateResponse(
729 executed=False,
729 executed=False,
730 reason=UpdateFailureReason.NO_CHANGE,
730 reason=UpdateFailureReason.NO_CHANGE,
731 old=pull_request, new=None, changes=None,
731 old=pull_request, new=None, changes=None,
732 source_changed=target_changed, target_changed=source_changed)
732 source_changed=target_changed, target_changed=source_changed)
733
733
734 change_in_found = 'target repo' if target_changed else 'source repo'
734 change_in_found = 'target repo' if target_changed else 'source repo'
735 log.debug('Updating pull request because of change in %s detected',
735 log.debug('Updating pull request because of change in %s detected',
736 change_in_found)
736 change_in_found)
737
737
738 # Finally there is a need for an update, in case of source change
738 # Finally there is a need for an update, in case of source change
739 # we create a new version, else just an update
739 # we create a new version, else just an update
740 if source_changed:
740 if source_changed:
741 pull_request_version = self._create_version_from_snapshot(pull_request)
741 pull_request_version = self._create_version_from_snapshot(pull_request)
742 self._link_comments_to_version(pull_request_version)
742 self._link_comments_to_version(pull_request_version)
743 else:
743 else:
744 try:
744 try:
745 ver = pull_request.versions[-1]
745 ver = pull_request.versions[-1]
746 except IndexError:
746 except IndexError:
747 ver = None
747 ver = None
748
748
749 pull_request.pull_request_version_id = \
749 pull_request.pull_request_version_id = \
750 ver.pull_request_version_id if ver else None
750 ver.pull_request_version_id if ver else None
751 pull_request_version = pull_request
751 pull_request_version = pull_request
752
752
753 try:
753 try:
754 if target_ref_type in ('tag', 'branch', 'book'):
754 if target_ref_type in ('tag', 'branch', 'book'):
755 target_commit = target_repo.get_commit(target_ref_name)
755 target_commit = target_repo.get_commit(target_ref_name)
756 else:
756 else:
757 target_commit = target_repo.get_commit(target_ref_id)
757 target_commit = target_repo.get_commit(target_ref_id)
758 except CommitDoesNotExistError:
758 except CommitDoesNotExistError:
759 return UpdateResponse(
759 return UpdateResponse(
760 executed=False,
760 executed=False,
761 reason=UpdateFailureReason.MISSING_TARGET_REF,
761 reason=UpdateFailureReason.MISSING_TARGET_REF,
762 old=pull_request, new=None, changes=None,
762 old=pull_request, new=None, changes=None,
763 source_changed=source_changed, target_changed=target_changed)
763 source_changed=source_changed, target_changed=target_changed)
764
764
765 # re-compute commit ids
765 # re-compute commit ids
766 old_commit_ids = pull_request.revisions
766 old_commit_ids = pull_request.revisions
767 pre_load = ["author", "branch", "date", "message"]
767 pre_load = ["author", "branch", "date", "message"]
768 commit_ranges = target_repo.compare(
768 commit_ranges = target_repo.compare(
769 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
769 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
770 pre_load=pre_load)
770 pre_load=pre_load)
771
771
772 ancestor = target_repo.get_common_ancestor(
772 ancestor = target_repo.get_common_ancestor(
773 target_commit.raw_id, source_commit.raw_id, source_repo)
773 target_commit.raw_id, source_commit.raw_id, source_repo)
774
774
775 pull_request.source_ref = '%s:%s:%s' % (
775 pull_request.source_ref = '%s:%s:%s' % (
776 source_ref_type, source_ref_name, source_commit.raw_id)
776 source_ref_type, source_ref_name, source_commit.raw_id)
777 pull_request.target_ref = '%s:%s:%s' % (
777 pull_request.target_ref = '%s:%s:%s' % (
778 target_ref_type, target_ref_name, ancestor)
778 target_ref_type, target_ref_name, ancestor)
779
779
780 pull_request.revisions = [
780 pull_request.revisions = [
781 commit.raw_id for commit in reversed(commit_ranges)]
781 commit.raw_id for commit in reversed(commit_ranges)]
782 pull_request.updated_on = datetime.datetime.now()
782 pull_request.updated_on = datetime.datetime.now()
783 Session().add(pull_request)
783 Session().add(pull_request)
784 new_commit_ids = pull_request.revisions
784 new_commit_ids = pull_request.revisions
785
785
786 old_diff_data, new_diff_data = self._generate_update_diffs(
786 old_diff_data, new_diff_data = self._generate_update_diffs(
787 pull_request, pull_request_version)
787 pull_request, pull_request_version)
788
788
789 # calculate commit and file changes
789 # calculate commit and file changes
790 changes = self._calculate_commit_id_changes(
790 changes = self._calculate_commit_id_changes(
791 old_commit_ids, new_commit_ids)
791 old_commit_ids, new_commit_ids)
792 file_changes = self._calculate_file_changes(
792 file_changes = self._calculate_file_changes(
793 old_diff_data, new_diff_data)
793 old_diff_data, new_diff_data)
794
794
795 # set comments as outdated if DIFFS changed
795 # set comments as outdated if DIFFS changed
796 CommentsModel().outdate_comments(
796 CommentsModel().outdate_comments(
797 pull_request, old_diff_data=old_diff_data,
797 pull_request, old_diff_data=old_diff_data,
798 new_diff_data=new_diff_data)
798 new_diff_data=new_diff_data)
799
799
800 commit_changes = (changes.added or changes.removed)
800 commit_changes = (changes.added or changes.removed)
801 file_node_changes = (
801 file_node_changes = (
802 file_changes.added or file_changes.modified or file_changes.removed)
802 file_changes.added or file_changes.modified or file_changes.removed)
803 pr_has_changes = commit_changes or file_node_changes
803 pr_has_changes = commit_changes or file_node_changes
804
804
805 # Add an automatic comment to the pull request, in case
805 # Add an automatic comment to the pull request, in case
806 # anything has changed
806 # anything has changed
807 if pr_has_changes:
807 if pr_has_changes:
808 update_comment = CommentsModel().create(
808 update_comment = CommentsModel().create(
809 text=self._render_update_message(changes, file_changes),
809 text=self._render_update_message(changes, file_changes),
810 repo=pull_request.target_repo,
810 repo=pull_request.target_repo,
811 user=pull_request.author,
811 user=pull_request.author,
812 pull_request=pull_request,
812 pull_request=pull_request,
813 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
813 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
814
814
815 # Update status to "Under Review" for added commits
815 # Update status to "Under Review" for added commits
816 for commit_id in changes.added:
816 for commit_id in changes.added:
817 ChangesetStatusModel().set_status(
817 ChangesetStatusModel().set_status(
818 repo=pull_request.source_repo,
818 repo=pull_request.source_repo,
819 status=ChangesetStatus.STATUS_UNDER_REVIEW,
819 status=ChangesetStatus.STATUS_UNDER_REVIEW,
820 comment=update_comment,
820 comment=update_comment,
821 user=pull_request.author,
821 user=pull_request.author,
822 pull_request=pull_request,
822 pull_request=pull_request,
823 revision=commit_id)
823 revision=commit_id)
824
824
825 log.debug(
825 log.debug(
826 'Updated pull request %s, added_ids: %s, common_ids: %s, '
826 'Updated pull request %s, added_ids: %s, common_ids: %s, '
827 'removed_ids: %s', pull_request.pull_request_id,
827 'removed_ids: %s', pull_request.pull_request_id,
828 changes.added, changes.common, changes.removed)
828 changes.added, changes.common, changes.removed)
829 log.debug(
829 log.debug(
830 'Updated pull request with the following file changes: %s',
830 'Updated pull request with the following file changes: %s',
831 file_changes)
831 file_changes)
832
832
833 log.info(
833 log.info(
834 "Updated pull request %s from commit %s to commit %s, "
834 "Updated pull request %s from commit %s to commit %s, "
835 "stored new version %s of this pull request.",
835 "stored new version %s of this pull request.",
836 pull_request.pull_request_id, source_ref_id,
836 pull_request.pull_request_id, source_ref_id,
837 pull_request.source_ref_parts.commit_id,
837 pull_request.source_ref_parts.commit_id,
838 pull_request_version.pull_request_version_id)
838 pull_request_version.pull_request_version_id)
839 Session().commit()
839 Session().commit()
840 self._trigger_pull_request_hook(
840 self._trigger_pull_request_hook(
841 pull_request, pull_request.author, 'update')
841 pull_request, pull_request.author, 'update')
842
842
843 return UpdateResponse(
843 return UpdateResponse(
844 executed=True, reason=UpdateFailureReason.NONE,
844 executed=True, reason=UpdateFailureReason.NONE,
845 old=pull_request, new=pull_request_version, changes=changes,
845 old=pull_request, new=pull_request_version, changes=changes,
846 source_changed=source_changed, target_changed=target_changed)
846 source_changed=source_changed, target_changed=target_changed)
847
847
848 def _create_version_from_snapshot(self, pull_request):
848 def _create_version_from_snapshot(self, pull_request):
849 version = PullRequestVersion()
849 version = PullRequestVersion()
850 version.title = pull_request.title
850 version.title = pull_request.title
851 version.description = pull_request.description
851 version.description = pull_request.description
852 version.status = pull_request.status
852 version.status = pull_request.status
853 version.created_on = datetime.datetime.now()
853 version.created_on = datetime.datetime.now()
854 version.updated_on = pull_request.updated_on
854 version.updated_on = pull_request.updated_on
855 version.user_id = pull_request.user_id
855 version.user_id = pull_request.user_id
856 version.source_repo = pull_request.source_repo
856 version.source_repo = pull_request.source_repo
857 version.source_ref = pull_request.source_ref
857 version.source_ref = pull_request.source_ref
858 version.target_repo = pull_request.target_repo
858 version.target_repo = pull_request.target_repo
859 version.target_ref = pull_request.target_ref
859 version.target_ref = pull_request.target_ref
860
860
861 version._last_merge_source_rev = pull_request._last_merge_source_rev
861 version._last_merge_source_rev = pull_request._last_merge_source_rev
862 version._last_merge_target_rev = pull_request._last_merge_target_rev
862 version._last_merge_target_rev = pull_request._last_merge_target_rev
863 version.last_merge_status = pull_request.last_merge_status
863 version.last_merge_status = pull_request.last_merge_status
864 version.shadow_merge_ref = pull_request.shadow_merge_ref
864 version.shadow_merge_ref = pull_request.shadow_merge_ref
865 version.merge_rev = pull_request.merge_rev
865 version.merge_rev = pull_request.merge_rev
866 version.reviewer_data = pull_request.reviewer_data
866 version.reviewer_data = pull_request.reviewer_data
867
867
868 version.revisions = pull_request.revisions
868 version.revisions = pull_request.revisions
869 version.pull_request = pull_request
869 version.pull_request = pull_request
870 Session().add(version)
870 Session().add(version)
871 Session().flush()
871 Session().flush()
872
872
873 return version
873 return version
874
874
875 def _generate_update_diffs(self, pull_request, pull_request_version):
875 def _generate_update_diffs(self, pull_request, pull_request_version):
876
876
877 diff_context = (
877 diff_context = (
878 self.DIFF_CONTEXT +
878 self.DIFF_CONTEXT +
879 CommentsModel.needed_extra_diff_context())
879 CommentsModel.needed_extra_diff_context())
880 hide_whitespace_changes = False
880 hide_whitespace_changes = False
881 source_repo = pull_request_version.source_repo
881 source_repo = pull_request_version.source_repo
882 source_ref_id = pull_request_version.source_ref_parts.commit_id
882 source_ref_id = pull_request_version.source_ref_parts.commit_id
883 target_ref_id = pull_request_version.target_ref_parts.commit_id
883 target_ref_id = pull_request_version.target_ref_parts.commit_id
884 old_diff = self._get_diff_from_pr_or_version(
884 old_diff = self._get_diff_from_pr_or_version(
885 source_repo, source_ref_id, target_ref_id,
885 source_repo, source_ref_id, target_ref_id,
886 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
886 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
887
887
888 source_repo = pull_request.source_repo
888 source_repo = pull_request.source_repo
889 source_ref_id = pull_request.source_ref_parts.commit_id
889 source_ref_id = pull_request.source_ref_parts.commit_id
890 target_ref_id = pull_request.target_ref_parts.commit_id
890 target_ref_id = pull_request.target_ref_parts.commit_id
891
891
892 new_diff = self._get_diff_from_pr_or_version(
892 new_diff = self._get_diff_from_pr_or_version(
893 source_repo, source_ref_id, target_ref_id,
893 source_repo, source_ref_id, target_ref_id,
894 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
894 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
895
895
896 old_diff_data = diffs.DiffProcessor(old_diff)
896 old_diff_data = diffs.DiffProcessor(old_diff)
897 old_diff_data.prepare()
897 old_diff_data.prepare()
898 new_diff_data = diffs.DiffProcessor(new_diff)
898 new_diff_data = diffs.DiffProcessor(new_diff)
899 new_diff_data.prepare()
899 new_diff_data.prepare()
900
900
901 return old_diff_data, new_diff_data
901 return old_diff_data, new_diff_data
902
902
903 def _link_comments_to_version(self, pull_request_version):
903 def _link_comments_to_version(self, pull_request_version):
904 """
904 """
905 Link all unlinked comments of this pull request to the given version.
905 Link all unlinked comments of this pull request to the given version.
906
906
907 :param pull_request_version: The `PullRequestVersion` to which
907 :param pull_request_version: The `PullRequestVersion` to which
908 the comments shall be linked.
908 the comments shall be linked.
909
909
910 """
910 """
911 pull_request = pull_request_version.pull_request
911 pull_request = pull_request_version.pull_request
912 comments = ChangesetComment.query()\
912 comments = ChangesetComment.query()\
913 .filter(
913 .filter(
914 # TODO: johbo: Should we query for the repo at all here?
914 # TODO: johbo: Should we query for the repo at all here?
915 # Pending decision on how comments of PRs are to be related
915 # Pending decision on how comments of PRs are to be related
916 # to either the source repo, the target repo or no repo at all.
916 # to either the source repo, the target repo or no repo at all.
917 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
917 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
918 ChangesetComment.pull_request == pull_request,
918 ChangesetComment.pull_request == pull_request,
919 ChangesetComment.pull_request_version == None)\
919 ChangesetComment.pull_request_version == None)\
920 .order_by(ChangesetComment.comment_id.asc())
920 .order_by(ChangesetComment.comment_id.asc())
921
921
922 # TODO: johbo: Find out why this breaks if it is done in a bulk
922 # TODO: johbo: Find out why this breaks if it is done in a bulk
923 # operation.
923 # operation.
924 for comment in comments:
924 for comment in comments:
925 comment.pull_request_version_id = (
925 comment.pull_request_version_id = (
926 pull_request_version.pull_request_version_id)
926 pull_request_version.pull_request_version_id)
927 Session().add(comment)
927 Session().add(comment)
928
928
929 def _calculate_commit_id_changes(self, old_ids, new_ids):
929 def _calculate_commit_id_changes(self, old_ids, new_ids):
930 added = [x for x in new_ids if x not in old_ids]
930 added = [x for x in new_ids if x not in old_ids]
931 common = [x for x in new_ids if x in old_ids]
931 common = [x for x in new_ids if x in old_ids]
932 removed = [x for x in old_ids if x not in new_ids]
932 removed = [x for x in old_ids if x not in new_ids]
933 total = new_ids
933 total = new_ids
934 return ChangeTuple(added, common, removed, total)
934 return ChangeTuple(added, common, removed, total)
935
935
936 def _calculate_file_changes(self, old_diff_data, new_diff_data):
936 def _calculate_file_changes(self, old_diff_data, new_diff_data):
937
937
938 old_files = OrderedDict()
938 old_files = OrderedDict()
939 for diff_data in old_diff_data.parsed_diff:
939 for diff_data in old_diff_data.parsed_diff:
940 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
940 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
941
941
942 added_files = []
942 added_files = []
943 modified_files = []
943 modified_files = []
944 removed_files = []
944 removed_files = []
945 for diff_data in new_diff_data.parsed_diff:
945 for diff_data in new_diff_data.parsed_diff:
946 new_filename = diff_data['filename']
946 new_filename = diff_data['filename']
947 new_hash = md5_safe(diff_data['raw_diff'])
947 new_hash = md5_safe(diff_data['raw_diff'])
948
948
949 old_hash = old_files.get(new_filename)
949 old_hash = old_files.get(new_filename)
950 if not old_hash:
950 if not old_hash:
951 # file is not present in old diff, means it's added
951 # file is not present in old diff, means it's added
952 added_files.append(new_filename)
952 added_files.append(new_filename)
953 else:
953 else:
954 if new_hash != old_hash:
954 if new_hash != old_hash:
955 modified_files.append(new_filename)
955 modified_files.append(new_filename)
956 # now remove a file from old, since we have seen it already
956 # now remove a file from old, since we have seen it already
957 del old_files[new_filename]
957 del old_files[new_filename]
958
958
959 # removed files is when there are present in old, but not in NEW,
959 # removed files is when there are present in old, but not in NEW,
960 # since we remove old files that are present in new diff, left-overs
960 # since we remove old files that are present in new diff, left-overs
961 # if any should be the removed files
961 # if any should be the removed files
962 removed_files.extend(old_files.keys())
962 removed_files.extend(old_files.keys())
963
963
964 return FileChangeTuple(added_files, modified_files, removed_files)
964 return FileChangeTuple(added_files, modified_files, removed_files)
965
965
966 def _render_update_message(self, changes, file_changes):
966 def _render_update_message(self, changes, file_changes):
967 """
967 """
968 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
968 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
969 so it's always looking the same disregarding on which default
969 so it's always looking the same disregarding on which default
970 renderer system is using.
970 renderer system is using.
971
971
972 :param changes: changes named tuple
972 :param changes: changes named tuple
973 :param file_changes: file changes named tuple
973 :param file_changes: file changes named tuple
974
974
975 """
975 """
976 new_status = ChangesetStatus.get_status_lbl(
976 new_status = ChangesetStatus.get_status_lbl(
977 ChangesetStatus.STATUS_UNDER_REVIEW)
977 ChangesetStatus.STATUS_UNDER_REVIEW)
978
978
979 changed_files = (
979 changed_files = (
980 file_changes.added + file_changes.modified + file_changes.removed)
980 file_changes.added + file_changes.modified + file_changes.removed)
981
981
982 params = {
982 params = {
983 'under_review_label': new_status,
983 'under_review_label': new_status,
984 'added_commits': changes.added,
984 'added_commits': changes.added,
985 'removed_commits': changes.removed,
985 'removed_commits': changes.removed,
986 'changed_files': changed_files,
986 'changed_files': changed_files,
987 'added_files': file_changes.added,
987 'added_files': file_changes.added,
988 'modified_files': file_changes.modified,
988 'modified_files': file_changes.modified,
989 'removed_files': file_changes.removed,
989 'removed_files': file_changes.removed,
990 }
990 }
991 renderer = RstTemplateRenderer()
991 renderer = RstTemplateRenderer()
992 return renderer.render('pull_request_update.mako', **params)
992 return renderer.render('pull_request_update.mako', **params)
993
993
994 def edit(self, pull_request, title, description, description_renderer, user):
994 def edit(self, pull_request, title, description, description_renderer, user):
995 pull_request = self.__get_pull_request(pull_request)
995 pull_request = self.__get_pull_request(pull_request)
996 old_data = pull_request.get_api_data(with_merge_state=False)
996 old_data = pull_request.get_api_data(with_merge_state=False)
997 if pull_request.is_closed():
997 if pull_request.is_closed():
998 raise ValueError('This pull request is closed')
998 raise ValueError('This pull request is closed')
999 if title:
999 if title:
1000 pull_request.title = title
1000 pull_request.title = title
1001 pull_request.description = description
1001 pull_request.description = description
1002 pull_request.updated_on = datetime.datetime.now()
1002 pull_request.updated_on = datetime.datetime.now()
1003 pull_request.description_renderer = description_renderer
1003 pull_request.description_renderer = description_renderer
1004 Session().add(pull_request)
1004 Session().add(pull_request)
1005 self._log_audit_action(
1005 self._log_audit_action(
1006 'repo.pull_request.edit', {'old_data': old_data},
1006 'repo.pull_request.edit', {'old_data': old_data},
1007 user, pull_request)
1007 user, pull_request)
1008
1008
1009 def update_reviewers(self, pull_request, reviewer_data, user):
1009 def update_reviewers(self, pull_request, reviewer_data, user):
1010 """
1010 """
1011 Update the reviewers in the pull request
1011 Update the reviewers in the pull request
1012
1012
1013 :param pull_request: the pr to update
1013 :param pull_request: the pr to update
1014 :param reviewer_data: list of tuples
1014 :param reviewer_data: list of tuples
1015 [(user, ['reason1', 'reason2'], mandatory_flag, [rules])]
1015 [(user, ['reason1', 'reason2'], mandatory_flag, [rules])]
1016 """
1016 """
1017 pull_request = self.__get_pull_request(pull_request)
1017 pull_request = self.__get_pull_request(pull_request)
1018 if pull_request.is_closed():
1018 if pull_request.is_closed():
1019 raise ValueError('This pull request is closed')
1019 raise ValueError('This pull request is closed')
1020
1020
1021 reviewers = {}
1021 reviewers = {}
1022 for user_id, reasons, mandatory, rules in reviewer_data:
1022 for user_id, reasons, mandatory, rules in reviewer_data:
1023 if isinstance(user_id, (int, basestring)):
1023 if isinstance(user_id, (int, basestring)):
1024 user_id = self._get_user(user_id).user_id
1024 user_id = self._get_user(user_id).user_id
1025 reviewers[user_id] = {
1025 reviewers[user_id] = {
1026 'reasons': reasons, 'mandatory': mandatory}
1026 'reasons': reasons, 'mandatory': mandatory}
1027
1027
1028 reviewers_ids = set(reviewers.keys())
1028 reviewers_ids = set(reviewers.keys())
1029 current_reviewers = PullRequestReviewers.query()\
1029 current_reviewers = PullRequestReviewers.query()\
1030 .filter(PullRequestReviewers.pull_request ==
1030 .filter(PullRequestReviewers.pull_request ==
1031 pull_request).all()
1031 pull_request).all()
1032 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1032 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1033
1033
1034 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1034 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1035 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1035 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1036
1036
1037 log.debug("Adding %s reviewers", ids_to_add)
1037 log.debug("Adding %s reviewers", ids_to_add)
1038 log.debug("Removing %s reviewers", ids_to_remove)
1038 log.debug("Removing %s reviewers", ids_to_remove)
1039 changed = False
1039 changed = False
1040 for uid in ids_to_add:
1040 for uid in ids_to_add:
1041 changed = True
1041 changed = True
1042 _usr = self._get_user(uid)
1042 _usr = self._get_user(uid)
1043 reviewer = PullRequestReviewers()
1043 reviewer = PullRequestReviewers()
1044 reviewer.user = _usr
1044 reviewer.user = _usr
1045 reviewer.pull_request = pull_request
1045 reviewer.pull_request = pull_request
1046 reviewer.reasons = reviewers[uid]['reasons']
1046 reviewer.reasons = reviewers[uid]['reasons']
1047 # NOTE(marcink): mandatory shouldn't be changed now
1047 # NOTE(marcink): mandatory shouldn't be changed now
1048 # reviewer.mandatory = reviewers[uid]['reasons']
1048 # reviewer.mandatory = reviewers[uid]['reasons']
1049 Session().add(reviewer)
1049 Session().add(reviewer)
1050 self._log_audit_action(
1050 self._log_audit_action(
1051 'repo.pull_request.reviewer.add', {'data': reviewer.get_dict()},
1051 'repo.pull_request.reviewer.add', {'data': reviewer.get_dict()},
1052 user, pull_request)
1052 user, pull_request)
1053
1053
1054 for uid in ids_to_remove:
1054 for uid in ids_to_remove:
1055 changed = True
1055 changed = True
1056 reviewers = PullRequestReviewers.query()\
1056 reviewers = PullRequestReviewers.query()\
1057 .filter(PullRequestReviewers.user_id == uid,
1057 .filter(PullRequestReviewers.user_id == uid,
1058 PullRequestReviewers.pull_request == pull_request)\
1058 PullRequestReviewers.pull_request == pull_request)\
1059 .all()
1059 .all()
1060 # use .all() in case we accidentally added the same person twice
1060 # use .all() in case we accidentally added the same person twice
1061 # this CAN happen due to the lack of DB checks
1061 # this CAN happen due to the lack of DB checks
1062 for obj in reviewers:
1062 for obj in reviewers:
1063 old_data = obj.get_dict()
1063 old_data = obj.get_dict()
1064 Session().delete(obj)
1064 Session().delete(obj)
1065 self._log_audit_action(
1065 self._log_audit_action(
1066 'repo.pull_request.reviewer.delete',
1066 'repo.pull_request.reviewer.delete',
1067 {'old_data': old_data}, user, pull_request)
1067 {'old_data': old_data}, user, pull_request)
1068
1068
1069 if changed:
1069 if changed:
1070 pull_request.updated_on = datetime.datetime.now()
1070 pull_request.updated_on = datetime.datetime.now()
1071 Session().add(pull_request)
1071 Session().add(pull_request)
1072
1072
1073 self.notify_reviewers(pull_request, ids_to_add)
1073 self.notify_reviewers(pull_request, ids_to_add)
1074 return ids_to_add, ids_to_remove
1074 return ids_to_add, ids_to_remove
1075
1075
1076 def get_url(self, pull_request, request=None, permalink=False):
1076 def get_url(self, pull_request, request=None, permalink=False):
1077 if not request:
1077 if not request:
1078 request = get_current_request()
1078 request = get_current_request()
1079
1079
1080 if permalink:
1080 if permalink:
1081 return request.route_url(
1081 return request.route_url(
1082 'pull_requests_global',
1082 'pull_requests_global',
1083 pull_request_id=pull_request.pull_request_id,)
1083 pull_request_id=pull_request.pull_request_id,)
1084 else:
1084 else:
1085 return request.route_url('pullrequest_show',
1085 return request.route_url('pullrequest_show',
1086 repo_name=safe_str(pull_request.target_repo.repo_name),
1086 repo_name=safe_str(pull_request.target_repo.repo_name),
1087 pull_request_id=pull_request.pull_request_id,)
1087 pull_request_id=pull_request.pull_request_id,)
1088
1088
1089 def get_shadow_clone_url(self, pull_request, request=None):
1089 def get_shadow_clone_url(self, pull_request, request=None):
1090 """
1090 """
1091 Returns qualified url pointing to the shadow repository. If this pull
1091 Returns qualified url pointing to the shadow repository. If this pull
1092 request is closed there is no shadow repository and ``None`` will be
1092 request is closed there is no shadow repository and ``None`` will be
1093 returned.
1093 returned.
1094 """
1094 """
1095 if pull_request.is_closed():
1095 if pull_request.is_closed():
1096 return None
1096 return None
1097 else:
1097 else:
1098 pr_url = urllib.unquote(self.get_url(pull_request, request=request))
1098 pr_url = urllib.unquote(self.get_url(pull_request, request=request))
1099 return safe_unicode('{pr_url}/repository'.format(pr_url=pr_url))
1099 return safe_unicode('{pr_url}/repository'.format(pr_url=pr_url))
1100
1100
1101 def notify_reviewers(self, pull_request, reviewers_ids):
1101 def notify_reviewers(self, pull_request, reviewers_ids):
1102 # notification to reviewers
1102 # notification to reviewers
1103 if not reviewers_ids:
1103 if not reviewers_ids:
1104 return
1104 return
1105
1105
1106 pull_request_obj = pull_request
1106 pull_request_obj = pull_request
1107 # get the current participants of this pull request
1107 # get the current participants of this pull request
1108 recipients = reviewers_ids
1108 recipients = reviewers_ids
1109 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1109 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1110
1110
1111 pr_source_repo = pull_request_obj.source_repo
1111 pr_source_repo = pull_request_obj.source_repo
1112 pr_target_repo = pull_request_obj.target_repo
1112 pr_target_repo = pull_request_obj.target_repo
1113
1113
1114 pr_url = h.route_url('pullrequest_show',
1114 pr_url = h.route_url('pullrequest_show',
1115 repo_name=pr_target_repo.repo_name,
1115 repo_name=pr_target_repo.repo_name,
1116 pull_request_id=pull_request_obj.pull_request_id,)
1116 pull_request_id=pull_request_obj.pull_request_id,)
1117
1117
1118 # set some variables for email notification
1118 # set some variables for email notification
1119 pr_target_repo_url = h.route_url(
1119 pr_target_repo_url = h.route_url(
1120 'repo_summary', repo_name=pr_target_repo.repo_name)
1120 'repo_summary', repo_name=pr_target_repo.repo_name)
1121
1121
1122 pr_source_repo_url = h.route_url(
1122 pr_source_repo_url = h.route_url(
1123 'repo_summary', repo_name=pr_source_repo.repo_name)
1123 'repo_summary', repo_name=pr_source_repo.repo_name)
1124
1124
1125 # pull request specifics
1125 # pull request specifics
1126 pull_request_commits = [
1126 pull_request_commits = [
1127 (x.raw_id, x.message)
1127 (x.raw_id, x.message)
1128 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1128 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1129
1129
1130 kwargs = {
1130 kwargs = {
1131 'user': pull_request.author,
1131 'user': pull_request.author,
1132 'pull_request': pull_request_obj,
1132 'pull_request': pull_request_obj,
1133 'pull_request_commits': pull_request_commits,
1133 'pull_request_commits': pull_request_commits,
1134
1134
1135 'pull_request_target_repo': pr_target_repo,
1135 'pull_request_target_repo': pr_target_repo,
1136 'pull_request_target_repo_url': pr_target_repo_url,
1136 'pull_request_target_repo_url': pr_target_repo_url,
1137
1137
1138 'pull_request_source_repo': pr_source_repo,
1138 'pull_request_source_repo': pr_source_repo,
1139 'pull_request_source_repo_url': pr_source_repo_url,
1139 'pull_request_source_repo_url': pr_source_repo_url,
1140
1140
1141 'pull_request_url': pr_url,
1141 'pull_request_url': pr_url,
1142 }
1142 }
1143
1143
1144 # pre-generate the subject for notification itself
1144 # pre-generate the subject for notification itself
1145 (subject,
1145 (subject,
1146 _h, _e, # we don't care about those
1146 _h, _e, # we don't care about those
1147 body_plaintext) = EmailNotificationModel().render_email(
1147 body_plaintext) = EmailNotificationModel().render_email(
1148 notification_type, **kwargs)
1148 notification_type, **kwargs)
1149
1149
1150 # create notification objects, and emails
1150 # create notification objects, and emails
1151 NotificationModel().create(
1151 NotificationModel().create(
1152 created_by=pull_request.author,
1152 created_by=pull_request.author,
1153 notification_subject=subject,
1153 notification_subject=subject,
1154 notification_body=body_plaintext,
1154 notification_body=body_plaintext,
1155 notification_type=notification_type,
1155 notification_type=notification_type,
1156 recipients=recipients,
1156 recipients=recipients,
1157 email_kwargs=kwargs,
1157 email_kwargs=kwargs,
1158 )
1158 )
1159
1159
1160 def delete(self, pull_request, user):
1160 def delete(self, pull_request, user):
1161 pull_request = self.__get_pull_request(pull_request)
1161 pull_request = self.__get_pull_request(pull_request)
1162 old_data = pull_request.get_api_data(with_merge_state=False)
1162 old_data = pull_request.get_api_data(with_merge_state=False)
1163 self._cleanup_merge_workspace(pull_request)
1163 self._cleanup_merge_workspace(pull_request)
1164 self._log_audit_action(
1164 self._log_audit_action(
1165 'repo.pull_request.delete', {'old_data': old_data},
1165 'repo.pull_request.delete', {'old_data': old_data},
1166 user, pull_request)
1166 user, pull_request)
1167 Session().delete(pull_request)
1167 Session().delete(pull_request)
1168
1168
1169 def close_pull_request(self, pull_request, user):
1169 def close_pull_request(self, pull_request, user):
1170 pull_request = self.__get_pull_request(pull_request)
1170 pull_request = self.__get_pull_request(pull_request)
1171 self._cleanup_merge_workspace(pull_request)
1171 self._cleanup_merge_workspace(pull_request)
1172 pull_request.status = PullRequest.STATUS_CLOSED
1172 pull_request.status = PullRequest.STATUS_CLOSED
1173 pull_request.updated_on = datetime.datetime.now()
1173 pull_request.updated_on = datetime.datetime.now()
1174 Session().add(pull_request)
1174 Session().add(pull_request)
1175 self._trigger_pull_request_hook(
1175 self._trigger_pull_request_hook(
1176 pull_request, pull_request.author, 'close')
1176 pull_request, pull_request.author, 'close')
1177
1177
1178 pr_data = pull_request.get_api_data(with_merge_state=False)
1178 pr_data = pull_request.get_api_data(with_merge_state=False)
1179 self._log_audit_action(
1179 self._log_audit_action(
1180 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1180 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1181
1181
1182 def close_pull_request_with_comment(
1182 def close_pull_request_with_comment(
1183 self, pull_request, user, repo, message=None, auth_user=None):
1183 self, pull_request, user, repo, message=None, auth_user=None):
1184
1184
1185 pull_request_review_status = pull_request.calculated_review_status()
1185 pull_request_review_status = pull_request.calculated_review_status()
1186
1186
1187 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1187 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1188 # approved only if we have voting consent
1188 # approved only if we have voting consent
1189 status = ChangesetStatus.STATUS_APPROVED
1189 status = ChangesetStatus.STATUS_APPROVED
1190 else:
1190 else:
1191 status = ChangesetStatus.STATUS_REJECTED
1191 status = ChangesetStatus.STATUS_REJECTED
1192 status_lbl = ChangesetStatus.get_status_lbl(status)
1192 status_lbl = ChangesetStatus.get_status_lbl(status)
1193
1193
1194 default_message = (
1194 default_message = (
1195 'Closing with status change {transition_icon} {status}.'
1195 'Closing with status change {transition_icon} {status}.'
1196 ).format(transition_icon='>', status=status_lbl)
1196 ).format(transition_icon='>', status=status_lbl)
1197 text = message or default_message
1197 text = message or default_message
1198
1198
1199 # create a comment, and link it to new status
1199 # create a comment, and link it to new status
1200 comment = CommentsModel().create(
1200 comment = CommentsModel().create(
1201 text=text,
1201 text=text,
1202 repo=repo.repo_id,
1202 repo=repo.repo_id,
1203 user=user.user_id,
1203 user=user.user_id,
1204 pull_request=pull_request.pull_request_id,
1204 pull_request=pull_request.pull_request_id,
1205 status_change=status_lbl,
1205 status_change=status_lbl,
1206 status_change_type=status,
1206 status_change_type=status,
1207 closing_pr=True,
1207 closing_pr=True,
1208 auth_user=auth_user,
1208 auth_user=auth_user,
1209 )
1209 )
1210
1210
1211 # calculate old status before we change it
1211 # calculate old status before we change it
1212 old_calculated_status = pull_request.calculated_review_status()
1212 old_calculated_status = pull_request.calculated_review_status()
1213 ChangesetStatusModel().set_status(
1213 ChangesetStatusModel().set_status(
1214 repo.repo_id,
1214 repo.repo_id,
1215 status,
1215 status,
1216 user.user_id,
1216 user.user_id,
1217 comment=comment,
1217 comment=comment,
1218 pull_request=pull_request.pull_request_id
1218 pull_request=pull_request.pull_request_id
1219 )
1219 )
1220
1220
1221 Session().flush()
1221 Session().flush()
1222 events.trigger(events.PullRequestCommentEvent(pull_request, comment))
1222 events.trigger(events.PullRequestCommentEvent(pull_request, comment))
1223 # we now calculate the status of pull request again, and based on that
1223 # we now calculate the status of pull request again, and based on that
1224 # calculation trigger status change. This might happen in cases
1224 # calculation trigger status change. This might happen in cases
1225 # that non-reviewer admin closes a pr, which means his vote doesn't
1225 # that non-reviewer admin closes a pr, which means his vote doesn't
1226 # change the status, while if he's a reviewer this might change it.
1226 # change the status, while if he's a reviewer this might change it.
1227 calculated_status = pull_request.calculated_review_status()
1227 calculated_status = pull_request.calculated_review_status()
1228 if old_calculated_status != calculated_status:
1228 if old_calculated_status != calculated_status:
1229 self._trigger_pull_request_hook(
1229 self._trigger_pull_request_hook(
1230 pull_request, user, 'review_status_change')
1230 pull_request, user, 'review_status_change')
1231
1231
1232 # finally close the PR
1232 # finally close the PR
1233 PullRequestModel().close_pull_request(
1233 PullRequestModel().close_pull_request(
1234 pull_request.pull_request_id, user)
1234 pull_request.pull_request_id, user)
1235
1235
1236 return comment, status
1236 return comment, status
1237
1237
1238 def merge_status(self, pull_request, translator=None,
1238 def merge_status(self, pull_request, translator=None,
1239 force_shadow_repo_refresh=False):
1239 force_shadow_repo_refresh=False):
1240 _ = translator or get_current_request().translate
1240 _ = translator or get_current_request().translate
1241
1241
1242 if not self._is_merge_enabled(pull_request):
1242 if not self._is_merge_enabled(pull_request):
1243 return False, _('Server-side pull request merging is disabled.')
1243 return False, _('Server-side pull request merging is disabled.')
1244 if pull_request.is_closed():
1244 if pull_request.is_closed():
1245 return False, _('This pull request is closed.')
1245 return False, _('This pull request is closed.')
1246 merge_possible, msg = self._check_repo_requirements(
1246 merge_possible, msg = self._check_repo_requirements(
1247 target=pull_request.target_repo, source=pull_request.source_repo,
1247 target=pull_request.target_repo, source=pull_request.source_repo,
1248 translator=_)
1248 translator=_)
1249 if not merge_possible:
1249 if not merge_possible:
1250 return merge_possible, msg
1250 return merge_possible, msg
1251
1251
1252 try:
1252 try:
1253 resp = self._try_merge(
1253 resp = self._try_merge(
1254 pull_request,
1254 pull_request,
1255 force_shadow_repo_refresh=force_shadow_repo_refresh)
1255 force_shadow_repo_refresh=force_shadow_repo_refresh)
1256 log.debug("Merge response: %s", resp)
1256 log.debug("Merge response: %s", resp)
1257 status = resp.possible, self.merge_status_message(
1257 status = resp.possible, self.merge_status_message(
1258 resp.failure_reason)
1258 resp.failure_reason)
1259 except NotImplementedError:
1259 except NotImplementedError:
1260 status = False, _('Pull request merging is not supported.')
1260 status = False, _('Pull request merging is not supported.')
1261
1261
1262 return status
1262 return status
1263
1263
1264 def _check_repo_requirements(self, target, source, translator):
1264 def _check_repo_requirements(self, target, source, translator):
1265 """
1265 """
1266 Check if `target` and `source` have compatible requirements.
1266 Check if `target` and `source` have compatible requirements.
1267
1267
1268 Currently this is just checking for largefiles.
1268 Currently this is just checking for largefiles.
1269 """
1269 """
1270 _ = translator
1270 _ = translator
1271 target_has_largefiles = self._has_largefiles(target)
1271 target_has_largefiles = self._has_largefiles(target)
1272 source_has_largefiles = self._has_largefiles(source)
1272 source_has_largefiles = self._has_largefiles(source)
1273 merge_possible = True
1273 merge_possible = True
1274 message = u''
1274 message = u''
1275
1275
1276 if target_has_largefiles != source_has_largefiles:
1276 if target_has_largefiles != source_has_largefiles:
1277 merge_possible = False
1277 merge_possible = False
1278 if source_has_largefiles:
1278 if source_has_largefiles:
1279 message = _(
1279 message = _(
1280 'Target repository large files support is disabled.')
1280 'Target repository large files support is disabled.')
1281 else:
1281 else:
1282 message = _(
1282 message = _(
1283 'Source repository large files support is disabled.')
1283 'Source repository large files support is disabled.')
1284
1284
1285 return merge_possible, message
1285 return merge_possible, message
1286
1286
1287 def _has_largefiles(self, repo):
1287 def _has_largefiles(self, repo):
1288 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1288 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1289 'extensions', 'largefiles')
1289 'extensions', 'largefiles')
1290 return largefiles_ui and largefiles_ui[0].active
1290 return largefiles_ui and largefiles_ui[0].active
1291
1291
1292 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1292 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1293 """
1293 """
1294 Try to merge the pull request and return the merge status.
1294 Try to merge the pull request and return the merge status.
1295 """
1295 """
1296 log.debug(
1296 log.debug(
1297 "Trying out if the pull request %s can be merged. Force_refresh=%s",
1297 "Trying out if the pull request %s can be merged. Force_refresh=%s",
1298 pull_request.pull_request_id, force_shadow_repo_refresh)
1298 pull_request.pull_request_id, force_shadow_repo_refresh)
1299 target_vcs = pull_request.target_repo.scm_instance()
1299 target_vcs = pull_request.target_repo.scm_instance()
1300
1300
1301 # Refresh the target reference.
1301 # Refresh the target reference.
1302 try:
1302 try:
1303 target_ref = self._refresh_reference(
1303 target_ref = self._refresh_reference(
1304 pull_request.target_ref_parts, target_vcs)
1304 pull_request.target_ref_parts, target_vcs)
1305 except CommitDoesNotExistError:
1305 except CommitDoesNotExistError:
1306 merge_state = MergeResponse(
1306 merge_state = MergeResponse(
1307 False, False, None, MergeFailureReason.MISSING_TARGET_REF)
1307 False, False, None, MergeFailureReason.MISSING_TARGET_REF)
1308 return merge_state
1308 return merge_state
1309
1309
1310 target_locked = pull_request.target_repo.locked
1310 target_locked = pull_request.target_repo.locked
1311 if target_locked and target_locked[0]:
1311 if target_locked and target_locked[0]:
1312 log.debug("The target repository is locked.")
1312 log.debug("The target repository is locked.")
1313 merge_state = MergeResponse(
1313 merge_state = MergeResponse(
1314 False, False, None, MergeFailureReason.TARGET_IS_LOCKED)
1314 False, False, None, MergeFailureReason.TARGET_IS_LOCKED)
1315 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1315 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1316 pull_request, target_ref):
1316 pull_request, target_ref):
1317 log.debug("Refreshing the merge status of the repository.")
1317 log.debug("Refreshing the merge status of the repository.")
1318 merge_state = self._refresh_merge_state(
1318 merge_state = self._refresh_merge_state(
1319 pull_request, target_vcs, target_ref)
1319 pull_request, target_vcs, target_ref)
1320 else:
1320 else:
1321 possible = pull_request.\
1321 possible = pull_request.\
1322 last_merge_status == MergeFailureReason.NONE
1322 last_merge_status == MergeFailureReason.NONE
1323 merge_state = MergeResponse(
1323 merge_state = MergeResponse(
1324 possible, False, None, pull_request.last_merge_status)
1324 possible, False, None, pull_request.last_merge_status)
1325
1325
1326 return merge_state
1326 return merge_state
1327
1327
1328 def _refresh_reference(self, reference, vcs_repository):
1328 def _refresh_reference(self, reference, vcs_repository):
1329 if reference.type in ('branch', 'book'):
1329 if reference.type in ('branch', 'book'):
1330 name_or_id = reference.name
1330 name_or_id = reference.name
1331 else:
1331 else:
1332 name_or_id = reference.commit_id
1332 name_or_id = reference.commit_id
1333 refreshed_commit = vcs_repository.get_commit(name_or_id)
1333 refreshed_commit = vcs_repository.get_commit(name_or_id)
1334 refreshed_reference = Reference(
1334 refreshed_reference = Reference(
1335 reference.type, reference.name, refreshed_commit.raw_id)
1335 reference.type, reference.name, refreshed_commit.raw_id)
1336 return refreshed_reference
1336 return refreshed_reference
1337
1337
1338 def _needs_merge_state_refresh(self, pull_request, target_reference):
1338 def _needs_merge_state_refresh(self, pull_request, target_reference):
1339 return not(
1339 return not(
1340 pull_request.revisions and
1340 pull_request.revisions and
1341 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1341 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1342 target_reference.commit_id == pull_request._last_merge_target_rev)
1342 target_reference.commit_id == pull_request._last_merge_target_rev)
1343
1343
1344 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1344 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1345 workspace_id = self._workspace_id(pull_request)
1345 workspace_id = self._workspace_id(pull_request)
1346 source_vcs = pull_request.source_repo.scm_instance()
1346 source_vcs = pull_request.source_repo.scm_instance()
1347 repo_id = pull_request.target_repo.repo_id
1347 repo_id = pull_request.target_repo.repo_id
1348 use_rebase = self._use_rebase_for_merging(pull_request)
1348 use_rebase = self._use_rebase_for_merging(pull_request)
1349 close_branch = self._close_branch_before_merging(pull_request)
1349 close_branch = self._close_branch_before_merging(pull_request)
1350 merge_state = target_vcs.merge(
1350 merge_state = target_vcs.merge(
1351 repo_id, workspace_id,
1351 repo_id, workspace_id,
1352 target_reference, source_vcs, pull_request.source_ref_parts,
1352 target_reference, source_vcs, pull_request.source_ref_parts,
1353 dry_run=True, use_rebase=use_rebase,
1353 dry_run=True, use_rebase=use_rebase,
1354 close_branch=close_branch)
1354 close_branch=close_branch)
1355
1355
1356 # Do not store the response if there was an unknown error.
1356 # Do not store the response if there was an unknown error.
1357 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1357 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1358 pull_request._last_merge_source_rev = \
1358 pull_request._last_merge_source_rev = \
1359 pull_request.source_ref_parts.commit_id
1359 pull_request.source_ref_parts.commit_id
1360 pull_request._last_merge_target_rev = target_reference.commit_id
1360 pull_request._last_merge_target_rev = target_reference.commit_id
1361 pull_request.last_merge_status = merge_state.failure_reason
1361 pull_request.last_merge_status = merge_state.failure_reason
1362 pull_request.shadow_merge_ref = merge_state.merge_ref
1362 pull_request.shadow_merge_ref = merge_state.merge_ref
1363 Session().add(pull_request)
1363 Session().add(pull_request)
1364 Session().commit()
1364 Session().commit()
1365
1365
1366 return merge_state
1366 return merge_state
1367
1367
1368 def _workspace_id(self, pull_request):
1368 def _workspace_id(self, pull_request):
1369 workspace_id = 'pr-%s' % pull_request.pull_request_id
1369 workspace_id = 'pr-%s' % pull_request.pull_request_id
1370 return workspace_id
1370 return workspace_id
1371
1371
1372 def merge_status_message(self, status_code):
1372 def merge_status_message(self, status_code):
1373 """
1373 """
1374 Return a human friendly error message for the given merge status code.
1374 Return a human friendly error message for the given merge status code.
1375 """
1375 """
1376 return self.MERGE_STATUS_MESSAGES[status_code]
1376 return self.MERGE_STATUS_MESSAGES[status_code]
1377
1377
1378 def generate_repo_data(self, repo, commit_id=None, branch=None,
1378 def generate_repo_data(self, repo, commit_id=None, branch=None,
1379 bookmark=None, translator=None):
1379 bookmark=None, translator=None):
1380 from rhodecode.model.repo import RepoModel
1380 from rhodecode.model.repo import RepoModel
1381
1381
1382 all_refs, selected_ref = \
1382 all_refs, selected_ref = \
1383 self._get_repo_pullrequest_sources(
1383 self._get_repo_pullrequest_sources(
1384 repo.scm_instance(), commit_id=commit_id,
1384 repo.scm_instance(), commit_id=commit_id,
1385 branch=branch, bookmark=bookmark, translator=translator)
1385 branch=branch, bookmark=bookmark, translator=translator)
1386
1386
1387 refs_select2 = []
1387 refs_select2 = []
1388 for element in all_refs:
1388 for element in all_refs:
1389 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1389 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1390 refs_select2.append({'text': element[1], 'children': children})
1390 refs_select2.append({'text': element[1], 'children': children})
1391
1391
1392 return {
1392 return {
1393 'user': {
1393 'user': {
1394 'user_id': repo.user.user_id,
1394 'user_id': repo.user.user_id,
1395 'username': repo.user.username,
1395 'username': repo.user.username,
1396 'firstname': repo.user.first_name,
1396 'firstname': repo.user.first_name,
1397 'lastname': repo.user.last_name,
1397 'lastname': repo.user.last_name,
1398 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1398 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1399 },
1399 },
1400 'name': repo.repo_name,
1400 'name': repo.repo_name,
1401 'link': RepoModel().get_url(repo),
1401 'link': RepoModel().get_url(repo),
1402 'description': h.chop_at_smart(repo.description_safe, '\n'),
1402 'description': h.chop_at_smart(repo.description_safe, '\n'),
1403 'refs': {
1403 'refs': {
1404 'all_refs': all_refs,
1404 'all_refs': all_refs,
1405 'selected_ref': selected_ref,
1405 'selected_ref': selected_ref,
1406 'select2_refs': refs_select2
1406 'select2_refs': refs_select2
1407 }
1407 }
1408 }
1408 }
1409
1409
1410 def generate_pullrequest_title(self, source, source_ref, target):
1410 def generate_pullrequest_title(self, source, source_ref, target):
1411 return u'{source}#{at_ref} to {target}'.format(
1411 return u'{source}#{at_ref} to {target}'.format(
1412 source=source,
1412 source=source,
1413 at_ref=source_ref,
1413 at_ref=source_ref,
1414 target=target,
1414 target=target,
1415 )
1415 )
1416
1416
1417 def _cleanup_merge_workspace(self, pull_request):
1417 def _cleanup_merge_workspace(self, pull_request):
1418 # Merging related cleanup
1418 # Merging related cleanup
1419 repo_id = pull_request.target_repo.repo_id
1419 repo_id = pull_request.target_repo.repo_id
1420 target_scm = pull_request.target_repo.scm_instance()
1420 target_scm = pull_request.target_repo.scm_instance()
1421 workspace_id = self._workspace_id(pull_request)
1421 workspace_id = self._workspace_id(pull_request)
1422
1422
1423 try:
1423 try:
1424 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
1424 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
1425 except NotImplementedError:
1425 except NotImplementedError:
1426 pass
1426 pass
1427
1427
1428 def _get_repo_pullrequest_sources(
1428 def _get_repo_pullrequest_sources(
1429 self, repo, commit_id=None, branch=None, bookmark=None,
1429 self, repo, commit_id=None, branch=None, bookmark=None,
1430 translator=None):
1430 translator=None):
1431 """
1431 """
1432 Return a structure with repo's interesting commits, suitable for
1432 Return a structure with repo's interesting commits, suitable for
1433 the selectors in pullrequest controller
1433 the selectors in pullrequest controller
1434
1434
1435 :param commit_id: a commit that must be in the list somehow
1435 :param commit_id: a commit that must be in the list somehow
1436 and selected by default
1436 and selected by default
1437 :param branch: a branch that must be in the list and selected
1437 :param branch: a branch that must be in the list and selected
1438 by default - even if closed
1438 by default - even if closed
1439 :param bookmark: a bookmark that must be in the list and selected
1439 :param bookmark: a bookmark that must be in the list and selected
1440 """
1440 """
1441 _ = translator or get_current_request().translate
1441 _ = translator or get_current_request().translate
1442
1442
1443 commit_id = safe_str(commit_id) if commit_id else None
1443 commit_id = safe_str(commit_id) if commit_id else None
1444 branch = safe_str(branch) if branch else None
1444 branch = safe_str(branch) if branch else None
1445 bookmark = safe_str(bookmark) if bookmark else None
1445 bookmark = safe_str(bookmark) if bookmark else None
1446
1446
1447 selected = None
1447 selected = None
1448
1448
1449 # order matters: first source that has commit_id in it will be selected
1449 # order matters: first source that has commit_id in it will be selected
1450 sources = []
1450 sources = []
1451 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
1451 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
1452 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
1452 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
1453
1453
1454 if commit_id:
1454 if commit_id:
1455 ref_commit = (h.short_id(commit_id), commit_id)
1455 ref_commit = (h.short_id(commit_id), commit_id)
1456 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
1456 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
1457
1457
1458 sources.append(
1458 sources.append(
1459 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
1459 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
1460 )
1460 )
1461
1461
1462 groups = []
1462 groups = []
1463 for group_key, ref_list, group_name, match in sources:
1463 for group_key, ref_list, group_name, match in sources:
1464 group_refs = []
1464 group_refs = []
1465 for ref_name, ref_id in ref_list:
1465 for ref_name, ref_id in ref_list:
1466 ref_key = '%s:%s:%s' % (group_key, ref_name, ref_id)
1466 ref_key = '%s:%s:%s' % (group_key, ref_name, ref_id)
1467 group_refs.append((ref_key, ref_name))
1467 group_refs.append((ref_key, ref_name))
1468
1468
1469 if not selected:
1469 if not selected:
1470 if set([commit_id, match]) & set([ref_id, ref_name]):
1470 if set([commit_id, match]) & set([ref_id, ref_name]):
1471 selected = ref_key
1471 selected = ref_key
1472
1472
1473 if group_refs:
1473 if group_refs:
1474 groups.append((group_refs, group_name))
1474 groups.append((group_refs, group_name))
1475
1475
1476 if not selected:
1476 if not selected:
1477 ref = commit_id or branch or bookmark
1477 ref = commit_id or branch or bookmark
1478 if ref:
1478 if ref:
1479 raise CommitDoesNotExistError(
1479 raise CommitDoesNotExistError(
1480 'No commit refs could be found matching: %s' % ref)
1480 'No commit refs could be found matching: %s' % ref)
1481 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
1481 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
1482 selected = 'branch:%s:%s' % (
1482 selected = 'branch:%s:%s' % (
1483 repo.DEFAULT_BRANCH_NAME,
1483 repo.DEFAULT_BRANCH_NAME,
1484 repo.branches[repo.DEFAULT_BRANCH_NAME]
1484 repo.branches[repo.DEFAULT_BRANCH_NAME]
1485 )
1485 )
1486 elif repo.commit_ids:
1486 elif repo.commit_ids:
1487 # make the user select in this case
1487 # make the user select in this case
1488 selected = None
1488 selected = None
1489 else:
1489 else:
1490 raise EmptyRepositoryError()
1490 raise EmptyRepositoryError()
1491 return groups, selected
1491 return groups, selected
1492
1492
1493 def get_diff(self, source_repo, source_ref_id, target_ref_id,
1493 def get_diff(self, source_repo, source_ref_id, target_ref_id,
1494 hide_whitespace_changes, diff_context):
1494 hide_whitespace_changes, diff_context):
1495
1495
1496 return self._get_diff_from_pr_or_version(
1496 return self._get_diff_from_pr_or_version(
1497 source_repo, source_ref_id, target_ref_id,
1497 source_repo, source_ref_id, target_ref_id,
1498 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1498 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1499
1499
1500 def _get_diff_from_pr_or_version(
1500 def _get_diff_from_pr_or_version(
1501 self, source_repo, source_ref_id, target_ref_id,
1501 self, source_repo, source_ref_id, target_ref_id,
1502 hide_whitespace_changes, diff_context):
1502 hide_whitespace_changes, diff_context):
1503
1503
1504 target_commit = source_repo.get_commit(
1504 target_commit = source_repo.get_commit(
1505 commit_id=safe_str(target_ref_id))
1505 commit_id=safe_str(target_ref_id))
1506 source_commit = source_repo.get_commit(
1506 source_commit = source_repo.get_commit(
1507 commit_id=safe_str(source_ref_id))
1507 commit_id=safe_str(source_ref_id))
1508 if isinstance(source_repo, Repository):
1508 if isinstance(source_repo, Repository):
1509 vcs_repo = source_repo.scm_instance()
1509 vcs_repo = source_repo.scm_instance()
1510 else:
1510 else:
1511 vcs_repo = source_repo
1511 vcs_repo = source_repo
1512
1512
1513 # TODO: johbo: In the context of an update, we cannot reach
1513 # TODO: johbo: In the context of an update, we cannot reach
1514 # the old commit anymore with our normal mechanisms. It needs
1514 # the old commit anymore with our normal mechanisms. It needs
1515 # some sort of special support in the vcs layer to avoid this
1515 # some sort of special support in the vcs layer to avoid this
1516 # workaround.
1516 # workaround.
1517 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
1517 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
1518 vcs_repo.alias == 'git'):
1518 vcs_repo.alias == 'git'):
1519 source_commit.raw_id = safe_str(source_ref_id)
1519 source_commit.raw_id = safe_str(source_ref_id)
1520
1520
1521 log.debug('calculating diff between '
1521 log.debug('calculating diff between '
1522 'source_ref:%s and target_ref:%s for repo `%s`',
1522 'source_ref:%s and target_ref:%s for repo `%s`',
1523 target_ref_id, source_ref_id,
1523 target_ref_id, source_ref_id,
1524 safe_unicode(vcs_repo.path))
1524 safe_unicode(vcs_repo.path))
1525
1525
1526 vcs_diff = vcs_repo.get_diff(
1526 vcs_diff = vcs_repo.get_diff(
1527 commit1=target_commit, commit2=source_commit,
1527 commit1=target_commit, commit2=source_commit,
1528 ignore_whitespace=hide_whitespace_changes, context=diff_context)
1528 ignore_whitespace=hide_whitespace_changes, context=diff_context)
1529 return vcs_diff
1529 return vcs_diff
1530
1530
1531 def _is_merge_enabled(self, pull_request):
1531 def _is_merge_enabled(self, pull_request):
1532 return self._get_general_setting(
1532 return self._get_general_setting(
1533 pull_request, 'rhodecode_pr_merge_enabled')
1533 pull_request, 'rhodecode_pr_merge_enabled')
1534
1534
1535 def _use_rebase_for_merging(self, pull_request):
1535 def _use_rebase_for_merging(self, pull_request):
1536 repo_type = pull_request.target_repo.repo_type
1536 repo_type = pull_request.target_repo.repo_type
1537 if repo_type == 'hg':
1537 if repo_type == 'hg':
1538 return self._get_general_setting(
1538 return self._get_general_setting(
1539 pull_request, 'rhodecode_hg_use_rebase_for_merging')
1539 pull_request, 'rhodecode_hg_use_rebase_for_merging')
1540 elif repo_type == 'git':
1540 elif repo_type == 'git':
1541 return self._get_general_setting(
1541 return self._get_general_setting(
1542 pull_request, 'rhodecode_git_use_rebase_for_merging')
1542 pull_request, 'rhodecode_git_use_rebase_for_merging')
1543
1543
1544 return False
1544 return False
1545
1545
1546 def _close_branch_before_merging(self, pull_request):
1546 def _close_branch_before_merging(self, pull_request):
1547 repo_type = pull_request.target_repo.repo_type
1547 repo_type = pull_request.target_repo.repo_type
1548 if repo_type == 'hg':
1548 if repo_type == 'hg':
1549 return self._get_general_setting(
1549 return self._get_general_setting(
1550 pull_request, 'rhodecode_hg_close_branch_before_merging')
1550 pull_request, 'rhodecode_hg_close_branch_before_merging')
1551 elif repo_type == 'git':
1551 elif repo_type == 'git':
1552 return self._get_general_setting(
1552 return self._get_general_setting(
1553 pull_request, 'rhodecode_git_close_branch_before_merging')
1553 pull_request, 'rhodecode_git_close_branch_before_merging')
1554
1554
1555 return False
1555 return False
1556
1556
1557 def _get_general_setting(self, pull_request, settings_key, default=False):
1557 def _get_general_setting(self, pull_request, settings_key, default=False):
1558 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
1558 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
1559 settings = settings_model.get_general_settings()
1559 settings = settings_model.get_general_settings()
1560 return settings.get(settings_key, default)
1560 return settings.get(settings_key, default)
1561
1561
1562 def _log_audit_action(self, action, action_data, user, pull_request):
1562 def _log_audit_action(self, action, action_data, user, pull_request):
1563 audit_logger.store(
1563 audit_logger.store(
1564 action=action,
1564 action=action,
1565 action_data=action_data,
1565 action_data=action_data,
1566 user=user,
1566 user=user,
1567 repo=pull_request.target_repo)
1567 repo=pull_request.target_repo)
1568
1568
1569 def get_reviewer_functions(self):
1569 def get_reviewer_functions(self):
1570 """
1570 """
1571 Fetches functions for validation and fetching default reviewers.
1571 Fetches functions for validation and fetching default reviewers.
1572 If available we use the EE package, else we fallback to CE
1572 If available we use the EE package, else we fallback to CE
1573 package functions
1573 package functions
1574 """
1574 """
1575 try:
1575 try:
1576 from rc_reviewers.utils import get_default_reviewers_data
1576 from rc_reviewers.utils import get_default_reviewers_data
1577 from rc_reviewers.utils import validate_default_reviewers
1577 from rc_reviewers.utils import validate_default_reviewers
1578 except ImportError:
1578 except ImportError:
1579 from rhodecode.apps.repository.utils import \
1579 from rhodecode.apps.repository.utils import get_default_reviewers_data
1580 get_default_reviewers_data
1580 from rhodecode.apps.repository.utils import validate_default_reviewers
1581 from rhodecode.apps.repository.utils import \
1582 validate_default_reviewers
1583
1581
1584 return get_default_reviewers_data, validate_default_reviewers
1582 return get_default_reviewers_data, validate_default_reviewers
1585
1583
1586
1584
1587 class MergeCheck(object):
1585 class MergeCheck(object):
1588 """
1586 """
1589 Perform Merge Checks and returns a check object which stores information
1587 Perform Merge Checks and returns a check object which stores information
1590 about merge errors, and merge conditions
1588 about merge errors, and merge conditions
1591 """
1589 """
1592 TODO_CHECK = 'todo'
1590 TODO_CHECK = 'todo'
1593 PERM_CHECK = 'perm'
1591 PERM_CHECK = 'perm'
1594 REVIEW_CHECK = 'review'
1592 REVIEW_CHECK = 'review'
1595 MERGE_CHECK = 'merge'
1593 MERGE_CHECK = 'merge'
1596
1594
1597 def __init__(self):
1595 def __init__(self):
1598 self.review_status = None
1596 self.review_status = None
1599 self.merge_possible = None
1597 self.merge_possible = None
1600 self.merge_msg = ''
1598 self.merge_msg = ''
1601 self.failed = None
1599 self.failed = None
1602 self.errors = []
1600 self.errors = []
1603 self.error_details = OrderedDict()
1601 self.error_details = OrderedDict()
1604
1602
1605 def push_error(self, error_type, message, error_key, details):
1603 def push_error(self, error_type, message, error_key, details):
1606 self.failed = True
1604 self.failed = True
1607 self.errors.append([error_type, message])
1605 self.errors.append([error_type, message])
1608 self.error_details[error_key] = dict(
1606 self.error_details[error_key] = dict(
1609 details=details,
1607 details=details,
1610 error_type=error_type,
1608 error_type=error_type,
1611 message=message
1609 message=message
1612 )
1610 )
1613
1611
1614 @classmethod
1612 @classmethod
1615 def validate(cls, pull_request, auth_user, translator, fail_early=False,
1613 def validate(cls, pull_request, auth_user, translator, fail_early=False,
1616 force_shadow_repo_refresh=False):
1614 force_shadow_repo_refresh=False):
1617 _ = translator
1615 _ = translator
1618 merge_check = cls()
1616 merge_check = cls()
1619
1617
1620 # permissions to merge
1618 # permissions to merge
1621 user_allowed_to_merge = PullRequestModel().check_user_merge(
1619 user_allowed_to_merge = PullRequestModel().check_user_merge(
1622 pull_request, auth_user)
1620 pull_request, auth_user)
1623 if not user_allowed_to_merge:
1621 if not user_allowed_to_merge:
1624 log.debug("MergeCheck: cannot merge, approval is pending.")
1622 log.debug("MergeCheck: cannot merge, approval is pending.")
1625
1623
1626 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
1624 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
1627 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1625 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1628 if fail_early:
1626 if fail_early:
1629 return merge_check
1627 return merge_check
1630
1628
1631 # permission to merge into the target branch
1629 # permission to merge into the target branch
1632 target_commit_id = pull_request.target_ref_parts.commit_id
1630 target_commit_id = pull_request.target_ref_parts.commit_id
1633 if pull_request.target_ref_parts.type == 'branch':
1631 if pull_request.target_ref_parts.type == 'branch':
1634 branch_name = pull_request.target_ref_parts.name
1632 branch_name = pull_request.target_ref_parts.name
1635 else:
1633 else:
1636 # for mercurial we can always figure out the branch from the commit
1634 # for mercurial we can always figure out the branch from the commit
1637 # in case of bookmark
1635 # in case of bookmark
1638 target_commit = pull_request.target_repo.get_commit(target_commit_id)
1636 target_commit = pull_request.target_repo.get_commit(target_commit_id)
1639 branch_name = target_commit.branch
1637 branch_name = target_commit.branch
1640
1638
1641 rule, branch_perm = auth_user.get_rule_and_branch_permission(
1639 rule, branch_perm = auth_user.get_rule_and_branch_permission(
1642 pull_request.target_repo.repo_name, branch_name)
1640 pull_request.target_repo.repo_name, branch_name)
1643 if branch_perm and branch_perm == 'branch.none':
1641 if branch_perm and branch_perm == 'branch.none':
1644 msg = _('Target branch `{}` changes rejected by rule {}.').format(
1642 msg = _('Target branch `{}` changes rejected by rule {}.').format(
1645 branch_name, rule)
1643 branch_name, rule)
1646 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1644 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1647 if fail_early:
1645 if fail_early:
1648 return merge_check
1646 return merge_check
1649
1647
1650 # review status, must be always present
1648 # review status, must be always present
1651 review_status = pull_request.calculated_review_status()
1649 review_status = pull_request.calculated_review_status()
1652 merge_check.review_status = review_status
1650 merge_check.review_status = review_status
1653
1651
1654 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
1652 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
1655 if not status_approved:
1653 if not status_approved:
1656 log.debug("MergeCheck: cannot merge, approval is pending.")
1654 log.debug("MergeCheck: cannot merge, approval is pending.")
1657
1655
1658 msg = _('Pull request reviewer approval is pending.')
1656 msg = _('Pull request reviewer approval is pending.')
1659
1657
1660 merge_check.push_error(
1658 merge_check.push_error(
1661 'warning', msg, cls.REVIEW_CHECK, review_status)
1659 'warning', msg, cls.REVIEW_CHECK, review_status)
1662
1660
1663 if fail_early:
1661 if fail_early:
1664 return merge_check
1662 return merge_check
1665
1663
1666 # left over TODOs
1664 # left over TODOs
1667 todos = CommentsModel().get_unresolved_todos(pull_request)
1665 todos = CommentsModel().get_unresolved_todos(pull_request)
1668 if todos:
1666 if todos:
1669 log.debug("MergeCheck: cannot merge, {} "
1667 log.debug("MergeCheck: cannot merge, {} "
1670 "unresolved todos left.".format(len(todos)))
1668 "unresolved todos left.".format(len(todos)))
1671
1669
1672 if len(todos) == 1:
1670 if len(todos) == 1:
1673 msg = _('Cannot merge, {} TODO still not resolved.').format(
1671 msg = _('Cannot merge, {} TODO still not resolved.').format(
1674 len(todos))
1672 len(todos))
1675 else:
1673 else:
1676 msg = _('Cannot merge, {} TODOs still not resolved.').format(
1674 msg = _('Cannot merge, {} TODOs still not resolved.').format(
1677 len(todos))
1675 len(todos))
1678
1676
1679 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
1677 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
1680
1678
1681 if fail_early:
1679 if fail_early:
1682 return merge_check
1680 return merge_check
1683
1681
1684 # merge possible, here is the filesystem simulation + shadow repo
1682 # merge possible, here is the filesystem simulation + shadow repo
1685 merge_status, msg = PullRequestModel().merge_status(
1683 merge_status, msg = PullRequestModel().merge_status(
1686 pull_request, translator=translator,
1684 pull_request, translator=translator,
1687 force_shadow_repo_refresh=force_shadow_repo_refresh)
1685 force_shadow_repo_refresh=force_shadow_repo_refresh)
1688 merge_check.merge_possible = merge_status
1686 merge_check.merge_possible = merge_status
1689 merge_check.merge_msg = msg
1687 merge_check.merge_msg = msg
1690 if not merge_status:
1688 if not merge_status:
1691 log.debug(
1689 log.debug(
1692 "MergeCheck: cannot merge, pull request merge not possible.")
1690 "MergeCheck: cannot merge, pull request merge not possible.")
1693 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
1691 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
1694
1692
1695 if fail_early:
1693 if fail_early:
1696 return merge_check
1694 return merge_check
1697
1695
1698 log.debug('MergeCheck: is failed: %s', merge_check.failed)
1696 log.debug('MergeCheck: is failed: %s', merge_check.failed)
1699 return merge_check
1697 return merge_check
1700
1698
1701 @classmethod
1699 @classmethod
1702 def get_merge_conditions(cls, pull_request, translator):
1700 def get_merge_conditions(cls, pull_request, translator):
1703 _ = translator
1701 _ = translator
1704 merge_details = {}
1702 merge_details = {}
1705
1703
1706 model = PullRequestModel()
1704 model = PullRequestModel()
1707 use_rebase = model._use_rebase_for_merging(pull_request)
1705 use_rebase = model._use_rebase_for_merging(pull_request)
1708
1706
1709 if use_rebase:
1707 if use_rebase:
1710 merge_details['merge_strategy'] = dict(
1708 merge_details['merge_strategy'] = dict(
1711 details={},
1709 details={},
1712 message=_('Merge strategy: rebase')
1710 message=_('Merge strategy: rebase')
1713 )
1711 )
1714 else:
1712 else:
1715 merge_details['merge_strategy'] = dict(
1713 merge_details['merge_strategy'] = dict(
1716 details={},
1714 details={},
1717 message=_('Merge strategy: explicit merge commit')
1715 message=_('Merge strategy: explicit merge commit')
1718 )
1716 )
1719
1717
1720 close_branch = model._close_branch_before_merging(pull_request)
1718 close_branch = model._close_branch_before_merging(pull_request)
1721 if close_branch:
1719 if close_branch:
1722 repo_type = pull_request.target_repo.repo_type
1720 repo_type = pull_request.target_repo.repo_type
1723 if repo_type == 'hg':
1721 if repo_type == 'hg':
1724 close_msg = _('Source branch will be closed after merge.')
1722 close_msg = _('Source branch will be closed after merge.')
1725 elif repo_type == 'git':
1723 elif repo_type == 'git':
1726 close_msg = _('Source branch will be deleted after merge.')
1724 close_msg = _('Source branch will be deleted after merge.')
1727
1725
1728 merge_details['close_branch'] = dict(
1726 merge_details['close_branch'] = dict(
1729 details={},
1727 details={},
1730 message=close_msg
1728 message=close_msg
1731 )
1729 )
1732
1730
1733 return merge_details
1731 return merge_details
1734
1732
1735 ChangeTuple = collections.namedtuple(
1733 ChangeTuple = collections.namedtuple(
1736 'ChangeTuple', ['added', 'common', 'removed', 'total'])
1734 'ChangeTuple', ['added', 'common', 'removed', 'total'])
1737
1735
1738 FileChangeTuple = collections.namedtuple(
1736 FileChangeTuple = collections.namedtuple(
1739 'FileChangeTuple', ['added', 'modified', 'removed'])
1737 'FileChangeTuple', ['added', 'modified', 'removed'])
@@ -1,838 +1,837 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2018 RhodeCode GmbH
3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import hashlib
22 import hashlib
23 import logging
23 import logging
24 from collections import namedtuple
24 from collections import namedtuple
25 from functools import wraps
25 from functools import wraps
26 import bleach
26 import bleach
27
27
28 from rhodecode.lib import rc_cache
28 from rhodecode.lib import rc_cache
29 from rhodecode.lib.utils2 import (
29 from rhodecode.lib.utils2 import (
30 Optional, AttributeDict, safe_str, remove_prefix, str2bool)
30 Optional, AttributeDict, safe_str, remove_prefix, str2bool)
31 from rhodecode.lib.vcs.backends import base
31 from rhodecode.lib.vcs.backends import base
32 from rhodecode.model import BaseModel
32 from rhodecode.model import BaseModel
33 from rhodecode.model.db import (
33 from rhodecode.model.db import (
34 RepoRhodeCodeUi, RepoRhodeCodeSetting, RhodeCodeUi, RhodeCodeSetting, CacheKey)
34 RepoRhodeCodeUi, RepoRhodeCodeSetting, RhodeCodeUi, RhodeCodeSetting, CacheKey)
35 from rhodecode.model.meta import Session
35 from rhodecode.model.meta import Session
36
36
37
37
38 log = logging.getLogger(__name__)
38 log = logging.getLogger(__name__)
39
39
40
40
41 UiSetting = namedtuple(
41 UiSetting = namedtuple(
42 'UiSetting', ['section', 'key', 'value', 'active'])
42 'UiSetting', ['section', 'key', 'value', 'active'])
43
43
44 SOCIAL_PLUGINS_LIST = ['github', 'bitbucket', 'twitter', 'google']
44 SOCIAL_PLUGINS_LIST = ['github', 'bitbucket', 'twitter', 'google']
45
45
46
46
47 class SettingNotFound(Exception):
47 class SettingNotFound(Exception):
48 def __init__(self, setting_id):
48 def __init__(self, setting_id):
49 msg = 'Setting `{}` is not found'.format(setting_id)
49 msg = 'Setting `{}` is not found'.format(setting_id)
50 super(SettingNotFound, self).__init__(msg)
50 super(SettingNotFound, self).__init__(msg)
51
51
52
52
53 class SettingsModel(BaseModel):
53 class SettingsModel(BaseModel):
54 BUILTIN_HOOKS = (
54 BUILTIN_HOOKS = (
55 RhodeCodeUi.HOOK_REPO_SIZE, RhodeCodeUi.HOOK_PUSH,
55 RhodeCodeUi.HOOK_REPO_SIZE, RhodeCodeUi.HOOK_PUSH,
56 RhodeCodeUi.HOOK_PRE_PUSH, RhodeCodeUi.HOOK_PRETX_PUSH,
56 RhodeCodeUi.HOOK_PRE_PUSH, RhodeCodeUi.HOOK_PRETX_PUSH,
57 RhodeCodeUi.HOOK_PULL, RhodeCodeUi.HOOK_PRE_PULL,
57 RhodeCodeUi.HOOK_PULL, RhodeCodeUi.HOOK_PRE_PULL,
58 RhodeCodeUi.HOOK_PUSH_KEY,)
58 RhodeCodeUi.HOOK_PUSH_KEY,)
59 HOOKS_SECTION = 'hooks'
59 HOOKS_SECTION = 'hooks'
60
60
61 def __init__(self, sa=None, repo=None):
61 def __init__(self, sa=None, repo=None):
62 self.repo = repo
62 self.repo = repo
63 self.UiDbModel = RepoRhodeCodeUi if repo else RhodeCodeUi
63 self.UiDbModel = RepoRhodeCodeUi if repo else RhodeCodeUi
64 self.SettingsDbModel = (
64 self.SettingsDbModel = (
65 RepoRhodeCodeSetting if repo else RhodeCodeSetting)
65 RepoRhodeCodeSetting if repo else RhodeCodeSetting)
66 super(SettingsModel, self).__init__(sa)
66 super(SettingsModel, self).__init__(sa)
67
67
68 def get_ui_by_key(self, key):
68 def get_ui_by_key(self, key):
69 q = self.UiDbModel.query()
69 q = self.UiDbModel.query()
70 q = q.filter(self.UiDbModel.ui_key == key)
70 q = q.filter(self.UiDbModel.ui_key == key)
71 q = self._filter_by_repo(RepoRhodeCodeUi, q)
71 q = self._filter_by_repo(RepoRhodeCodeUi, q)
72 return q.scalar()
72 return q.scalar()
73
73
74 def get_ui_by_section(self, section):
74 def get_ui_by_section(self, section):
75 q = self.UiDbModel.query()
75 q = self.UiDbModel.query()
76 q = q.filter(self.UiDbModel.ui_section == section)
76 q = q.filter(self.UiDbModel.ui_section == section)
77 q = self._filter_by_repo(RepoRhodeCodeUi, q)
77 q = self._filter_by_repo(RepoRhodeCodeUi, q)
78 return q.all()
78 return q.all()
79
79
80 def get_ui_by_section_and_key(self, section, key):
80 def get_ui_by_section_and_key(self, section, key):
81 q = self.UiDbModel.query()
81 q = self.UiDbModel.query()
82 q = q.filter(self.UiDbModel.ui_section == section)
82 q = q.filter(self.UiDbModel.ui_section == section)
83 q = q.filter(self.UiDbModel.ui_key == key)
83 q = q.filter(self.UiDbModel.ui_key == key)
84 q = self._filter_by_repo(RepoRhodeCodeUi, q)
84 q = self._filter_by_repo(RepoRhodeCodeUi, q)
85 return q.scalar()
85 return q.scalar()
86
86
87 def get_ui(self, section=None, key=None):
87 def get_ui(self, section=None, key=None):
88 q = self.UiDbModel.query()
88 q = self.UiDbModel.query()
89 q = self._filter_by_repo(RepoRhodeCodeUi, q)
89 q = self._filter_by_repo(RepoRhodeCodeUi, q)
90
90
91 if section:
91 if section:
92 q = q.filter(self.UiDbModel.ui_section == section)
92 q = q.filter(self.UiDbModel.ui_section == section)
93 if key:
93 if key:
94 q = q.filter(self.UiDbModel.ui_key == key)
94 q = q.filter(self.UiDbModel.ui_key == key)
95
95
96 # TODO: mikhail: add caching
96 # TODO: mikhail: add caching
97 result = [
97 result = [
98 UiSetting(
98 UiSetting(
99 section=safe_str(r.ui_section), key=safe_str(r.ui_key),
99 section=safe_str(r.ui_section), key=safe_str(r.ui_key),
100 value=safe_str(r.ui_value), active=r.ui_active
100 value=safe_str(r.ui_value), active=r.ui_active
101 )
101 )
102 for r in q.all()
102 for r in q.all()
103 ]
103 ]
104 return result
104 return result
105
105
106 def get_builtin_hooks(self):
106 def get_builtin_hooks(self):
107 q = self.UiDbModel.query()
107 q = self.UiDbModel.query()
108 q = q.filter(self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
108 q = q.filter(self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
109 return self._get_hooks(q)
109 return self._get_hooks(q)
110
110
111 def get_custom_hooks(self):
111 def get_custom_hooks(self):
112 q = self.UiDbModel.query()
112 q = self.UiDbModel.query()
113 q = q.filter(~self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
113 q = q.filter(~self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
114 return self._get_hooks(q)
114 return self._get_hooks(q)
115
115
116 def create_ui_section_value(self, section, val, key=None, active=True):
116 def create_ui_section_value(self, section, val, key=None, active=True):
117 new_ui = self.UiDbModel()
117 new_ui = self.UiDbModel()
118 new_ui.ui_section = section
118 new_ui.ui_section = section
119 new_ui.ui_value = val
119 new_ui.ui_value = val
120 new_ui.ui_active = active
120 new_ui.ui_active = active
121
121
122 if self.repo:
122 if self.repo:
123 repo = self._get_repo(self.repo)
123 repo = self._get_repo(self.repo)
124 repository_id = repo.repo_id
124 repository_id = repo.repo_id
125 new_ui.repository_id = repository_id
125 new_ui.repository_id = repository_id
126
126
127 if not key:
127 if not key:
128 # keys are unique so they need appended info
128 # keys are unique so they need appended info
129 if self.repo:
129 if self.repo:
130 key = hashlib.sha1(
130 key = hashlib.sha1(
131 '{}{}{}'.format(section, val, repository_id)).hexdigest()
131 '{}{}{}'.format(section, val, repository_id)).hexdigest()
132 else:
132 else:
133 key = hashlib.sha1('{}{}'.format(section, val)).hexdigest()
133 key = hashlib.sha1('{}{}'.format(section, val)).hexdigest()
134
134
135 new_ui.ui_key = key
135 new_ui.ui_key = key
136
136
137 Session().add(new_ui)
137 Session().add(new_ui)
138 return new_ui
138 return new_ui
139
139
140 def create_or_update_hook(self, key, value):
140 def create_or_update_hook(self, key, value):
141 ui = (
141 ui = (
142 self.get_ui_by_section_and_key(self.HOOKS_SECTION, key) or
142 self.get_ui_by_section_and_key(self.HOOKS_SECTION, key) or
143 self.UiDbModel())
143 self.UiDbModel())
144 ui.ui_section = self.HOOKS_SECTION
144 ui.ui_section = self.HOOKS_SECTION
145 ui.ui_active = True
145 ui.ui_active = True
146 ui.ui_key = key
146 ui.ui_key = key
147 ui.ui_value = value
147 ui.ui_value = value
148
148
149 if self.repo:
149 if self.repo:
150 repo = self._get_repo(self.repo)
150 repo = self._get_repo(self.repo)
151 repository_id = repo.repo_id
151 repository_id = repo.repo_id
152 ui.repository_id = repository_id
152 ui.repository_id = repository_id
153
153
154 Session().add(ui)
154 Session().add(ui)
155 return ui
155 return ui
156
156
157 def delete_ui(self, id_):
157 def delete_ui(self, id_):
158 ui = self.UiDbModel.get(id_)
158 ui = self.UiDbModel.get(id_)
159 if not ui:
159 if not ui:
160 raise SettingNotFound(id_)
160 raise SettingNotFound(id_)
161 Session().delete(ui)
161 Session().delete(ui)
162
162
163 def get_setting_by_name(self, name):
163 def get_setting_by_name(self, name):
164 q = self._get_settings_query()
164 q = self._get_settings_query()
165 q = q.filter(self.SettingsDbModel.app_settings_name == name)
165 q = q.filter(self.SettingsDbModel.app_settings_name == name)
166 return q.scalar()
166 return q.scalar()
167
167
168 def create_or_update_setting(
168 def create_or_update_setting(
169 self, name, val=Optional(''), type_=Optional('unicode')):
169 self, name, val=Optional(''), type_=Optional('unicode')):
170 """
170 """
171 Creates or updates RhodeCode setting. If updates is triggered it will
171 Creates or updates RhodeCode setting. If updates is triggered it will
172 only update parameters that are explicityl set Optional instance will
172 only update parameters that are explicityl set Optional instance will
173 be skipped
173 be skipped
174
174
175 :param name:
175 :param name:
176 :param val:
176 :param val:
177 :param type_:
177 :param type_:
178 :return:
178 :return:
179 """
179 """
180
180
181 res = self.get_setting_by_name(name)
181 res = self.get_setting_by_name(name)
182 repo = self._get_repo(self.repo) if self.repo else None
182 repo = self._get_repo(self.repo) if self.repo else None
183
183
184 if not res:
184 if not res:
185 val = Optional.extract(val)
185 val = Optional.extract(val)
186 type_ = Optional.extract(type_)
186 type_ = Optional.extract(type_)
187
187
188 args = (
188 args = (
189 (repo.repo_id, name, val, type_)
189 (repo.repo_id, name, val, type_)
190 if repo else (name, val, type_))
190 if repo else (name, val, type_))
191 res = self.SettingsDbModel(*args)
191 res = self.SettingsDbModel(*args)
192
192
193 else:
193 else:
194 if self.repo:
194 if self.repo:
195 res.repository_id = repo.repo_id
195 res.repository_id = repo.repo_id
196
196
197 res.app_settings_name = name
197 res.app_settings_name = name
198 if not isinstance(type_, Optional):
198 if not isinstance(type_, Optional):
199 # update if set
199 # update if set
200 res.app_settings_type = type_
200 res.app_settings_type = type_
201 if not isinstance(val, Optional):
201 if not isinstance(val, Optional):
202 # update if set
202 # update if set
203 res.app_settings_value = val
203 res.app_settings_value = val
204
204
205 Session().add(res)
205 Session().add(res)
206 return res
206 return res
207
207
208 def invalidate_settings_cache(self):
208 def invalidate_settings_cache(self):
209 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
209 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
210 CacheKey.set_invalidate(invalidation_namespace)
210 CacheKey.set_invalidate(invalidation_namespace)
211
211
212 def get_all_settings(self, cache=False):
212 def get_all_settings(self, cache=False):
213 region = rc_cache.get_or_create_region('sql_cache_short')
213 region = rc_cache.get_or_create_region('sql_cache_short')
214 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
214 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
215
215
216 @region.conditional_cache_on_arguments(condition=cache)
216 @region.conditional_cache_on_arguments(condition=cache)
217 def _get_all_settings(name, key):
217 def _get_all_settings(name, key):
218 q = self._get_settings_query()
218 q = self._get_settings_query()
219 if not q:
219 if not q:
220 raise Exception('Could not get application settings !')
220 raise Exception('Could not get application settings !')
221
221
222 settings = {
222 settings = {
223 'rhodecode_' + result.app_settings_name: result.app_settings_value
223 'rhodecode_' + result.app_settings_name: result.app_settings_value
224 for result in q
224 for result in q
225 }
225 }
226 return settings
226 return settings
227
227
228 repo = self._get_repo(self.repo) if self.repo else None
228 repo = self._get_repo(self.repo) if self.repo else None
229 key = "settings_repo.{}".format(repo.repo_id) if repo else "settings_app"
229 key = "settings_repo.{}".format(repo.repo_id) if repo else "settings_app"
230
230
231 inv_context_manager = rc_cache.InvalidationContext(
231 inv_context_manager = rc_cache.InvalidationContext(
232 uid='cache_settings', invalidation_namespace=invalidation_namespace)
232 uid='cache_settings', invalidation_namespace=invalidation_namespace)
233 with inv_context_manager as invalidation_context:
233 with inv_context_manager as invalidation_context:
234 # check for stored invalidation signal, and maybe purge the cache
234 # check for stored invalidation signal, and maybe purge the cache
235 # before computing it again
235 # before computing it again
236 if invalidation_context.should_invalidate():
236 if invalidation_context.should_invalidate():
237 # NOTE:(marcink) we flush the whole sql_cache_short region, because it
237 # NOTE:(marcink) we flush the whole sql_cache_short region, because it
238 # reads different settings etc. It's little too much but those caches
238 # reads different settings etc. It's little too much but those caches
239 # are anyway very short lived and it's a safest way.
239 # are anyway very short lived and it's a safest way.
240 region = rc_cache.get_or_create_region('sql_cache_short')
240 region = rc_cache.get_or_create_region('sql_cache_short')
241 region.invalidate()
241 region.invalidate()
242
242
243 result = _get_all_settings('rhodecode_settings', key)
243 result = _get_all_settings('rhodecode_settings', key)
244 log.debug('Fetching app settings for key: %s took: %.3fs', key,
244 log.debug('Fetching app settings for key: %s took: %.3fs', key,
245 inv_context_manager.compute_time)
245 inv_context_manager.compute_time)
246
246
247 return result
247 return result
248
248
249 def get_auth_settings(self):
249 def get_auth_settings(self):
250 q = self._get_settings_query()
250 q = self._get_settings_query()
251 q = q.filter(
251 q = q.filter(
252 self.SettingsDbModel.app_settings_name.startswith('auth_'))
252 self.SettingsDbModel.app_settings_name.startswith('auth_'))
253 rows = q.all()
253 rows = q.all()
254 auth_settings = {
254 auth_settings = {
255 row.app_settings_name: row.app_settings_value for row in rows}
255 row.app_settings_name: row.app_settings_value for row in rows}
256 return auth_settings
256 return auth_settings
257
257
258 def get_auth_plugins(self):
258 def get_auth_plugins(self):
259 auth_plugins = self.get_setting_by_name("auth_plugins")
259 auth_plugins = self.get_setting_by_name("auth_plugins")
260 return auth_plugins.app_settings_value
260 return auth_plugins.app_settings_value
261
261
262 def get_default_repo_settings(self, strip_prefix=False):
262 def get_default_repo_settings(self, strip_prefix=False):
263 q = self._get_settings_query()
263 q = self._get_settings_query()
264 q = q.filter(
264 q = q.filter(
265 self.SettingsDbModel.app_settings_name.startswith('default_'))
265 self.SettingsDbModel.app_settings_name.startswith('default_'))
266 rows = q.all()
266 rows = q.all()
267
267
268 result = {}
268 result = {}
269 for row in rows:
269 for row in rows:
270 key = row.app_settings_name
270 key = row.app_settings_name
271 if strip_prefix:
271 if strip_prefix:
272 key = remove_prefix(key, prefix='default_')
272 key = remove_prefix(key, prefix='default_')
273 result.update({key: row.app_settings_value})
273 result.update({key: row.app_settings_value})
274 return result
274 return result
275
275
276 def get_repo(self):
276 def get_repo(self):
277 repo = self._get_repo(self.repo)
277 repo = self._get_repo(self.repo)
278 if not repo:
278 if not repo:
279 raise Exception(
279 raise Exception(
280 'Repository `{}` cannot be found inside the database'.format(
280 'Repository `{}` cannot be found inside the database'.format(
281 self.repo))
281 self.repo))
282 return repo
282 return repo
283
283
284 def _filter_by_repo(self, model, query):
284 def _filter_by_repo(self, model, query):
285 if self.repo:
285 if self.repo:
286 repo = self.get_repo()
286 repo = self.get_repo()
287 query = query.filter(model.repository_id == repo.repo_id)
287 query = query.filter(model.repository_id == repo.repo_id)
288 return query
288 return query
289
289
290 def _get_hooks(self, query):
290 def _get_hooks(self, query):
291 query = query.filter(self.UiDbModel.ui_section == self.HOOKS_SECTION)
291 query = query.filter(self.UiDbModel.ui_section == self.HOOKS_SECTION)
292 query = self._filter_by_repo(RepoRhodeCodeUi, query)
292 query = self._filter_by_repo(RepoRhodeCodeUi, query)
293 return query.all()
293 return query.all()
294
294
295 def _get_settings_query(self):
295 def _get_settings_query(self):
296 q = self.SettingsDbModel.query()
296 q = self.SettingsDbModel.query()
297 return self._filter_by_repo(RepoRhodeCodeSetting, q)
297 return self._filter_by_repo(RepoRhodeCodeSetting, q)
298
298
299 def list_enabled_social_plugins(self, settings):
299 def list_enabled_social_plugins(self, settings):
300 enabled = []
300 enabled = []
301 for plug in SOCIAL_PLUGINS_LIST:
301 for plug in SOCIAL_PLUGINS_LIST:
302 if str2bool(settings.get('rhodecode_auth_{}_enabled'.format(plug)
302 if str2bool(settings.get('rhodecode_auth_{}_enabled'.format(plug)
303 )):
303 )):
304 enabled.append(plug)
304 enabled.append(plug)
305 return enabled
305 return enabled
306
306
307
307
308 def assert_repo_settings(func):
308 def assert_repo_settings(func):
309 @wraps(func)
309 @wraps(func)
310 def _wrapper(self, *args, **kwargs):
310 def _wrapper(self, *args, **kwargs):
311 if not self.repo_settings:
311 if not self.repo_settings:
312 raise Exception('Repository is not specified')
312 raise Exception('Repository is not specified')
313 return func(self, *args, **kwargs)
313 return func(self, *args, **kwargs)
314 return _wrapper
314 return _wrapper
315
315
316
316
317 class IssueTrackerSettingsModel(object):
317 class IssueTrackerSettingsModel(object):
318 INHERIT_SETTINGS = 'inherit_issue_tracker_settings'
318 INHERIT_SETTINGS = 'inherit_issue_tracker_settings'
319 SETTINGS_PREFIX = 'issuetracker_'
319 SETTINGS_PREFIX = 'issuetracker_'
320
320
321 def __init__(self, sa=None, repo=None):
321 def __init__(self, sa=None, repo=None):
322 self.global_settings = SettingsModel(sa=sa)
322 self.global_settings = SettingsModel(sa=sa)
323 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
323 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
324
324
325 @property
325 @property
326 def inherit_global_settings(self):
326 def inherit_global_settings(self):
327 if not self.repo_settings:
327 if not self.repo_settings:
328 return True
328 return True
329 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
329 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
330 return setting.app_settings_value if setting else True
330 return setting.app_settings_value if setting else True
331
331
332 @inherit_global_settings.setter
332 @inherit_global_settings.setter
333 def inherit_global_settings(self, value):
333 def inherit_global_settings(self, value):
334 if self.repo_settings:
334 if self.repo_settings:
335 settings = self.repo_settings.create_or_update_setting(
335 settings = self.repo_settings.create_or_update_setting(
336 self.INHERIT_SETTINGS, value, type_='bool')
336 self.INHERIT_SETTINGS, value, type_='bool')
337 Session().add(settings)
337 Session().add(settings)
338
338
339 def _get_keyname(self, key, uid, prefix=''):
339 def _get_keyname(self, key, uid, prefix=''):
340 return '{0}{1}{2}_{3}'.format(
340 return '{0}{1}{2}_{3}'.format(
341 prefix, self.SETTINGS_PREFIX, key, uid)
341 prefix, self.SETTINGS_PREFIX, key, uid)
342
342
343 def _make_dict_for_settings(self, qs):
343 def _make_dict_for_settings(self, qs):
344 prefix_match = self._get_keyname('pat', '', 'rhodecode_')
344 prefix_match = self._get_keyname('pat', '', 'rhodecode_')
345
345
346 issuetracker_entries = {}
346 issuetracker_entries = {}
347 # create keys
347 # create keys
348 for k, v in qs.items():
348 for k, v in qs.items():
349 if k.startswith(prefix_match):
349 if k.startswith(prefix_match):
350 uid = k[len(prefix_match):]
350 uid = k[len(prefix_match):]
351 issuetracker_entries[uid] = None
351 issuetracker_entries[uid] = None
352
352
353 # populate
353 # populate
354 for uid in issuetracker_entries:
354 for uid in issuetracker_entries:
355 issuetracker_entries[uid] = AttributeDict({
355 issuetracker_entries[uid] = AttributeDict({
356 'pat': qs.get(
356 'pat': qs.get(
357 self._get_keyname('pat', uid, 'rhodecode_')),
357 self._get_keyname('pat', uid, 'rhodecode_')),
358 'url': bleach.clean(
358 'url': bleach.clean(
359 qs.get(self._get_keyname('url', uid, 'rhodecode_')) or ''),
359 qs.get(self._get_keyname('url', uid, 'rhodecode_')) or ''),
360 'pref': bleach.clean(
360 'pref': bleach.clean(
361 qs.get(self._get_keyname('pref', uid, 'rhodecode_')) or ''),
361 qs.get(self._get_keyname('pref', uid, 'rhodecode_')) or ''),
362 'desc': qs.get(
362 'desc': qs.get(
363 self._get_keyname('desc', uid, 'rhodecode_')),
363 self._get_keyname('desc', uid, 'rhodecode_')),
364 })
364 })
365 return issuetracker_entries
365 return issuetracker_entries
366
366
367 def get_global_settings(self, cache=False):
367 def get_global_settings(self, cache=False):
368 """
368 """
369 Returns list of global issue tracker settings
369 Returns list of global issue tracker settings
370 """
370 """
371 defaults = self.global_settings.get_all_settings(cache=cache)
371 defaults = self.global_settings.get_all_settings(cache=cache)
372 settings = self._make_dict_for_settings(defaults)
372 settings = self._make_dict_for_settings(defaults)
373 return settings
373 return settings
374
374
375 def get_repo_settings(self, cache=False):
375 def get_repo_settings(self, cache=False):
376 """
376 """
377 Returns list of issue tracker settings per repository
377 Returns list of issue tracker settings per repository
378 """
378 """
379 if not self.repo_settings:
379 if not self.repo_settings:
380 raise Exception('Repository is not specified')
380 raise Exception('Repository is not specified')
381 all_settings = self.repo_settings.get_all_settings(cache=cache)
381 all_settings = self.repo_settings.get_all_settings(cache=cache)
382 settings = self._make_dict_for_settings(all_settings)
382 settings = self._make_dict_for_settings(all_settings)
383 return settings
383 return settings
384
384
385 def get_settings(self, cache=False):
385 def get_settings(self, cache=False):
386 if self.inherit_global_settings:
386 if self.inherit_global_settings:
387 return self.get_global_settings(cache=cache)
387 return self.get_global_settings(cache=cache)
388 else:
388 else:
389 return self.get_repo_settings(cache=cache)
389 return self.get_repo_settings(cache=cache)
390
390
391 def delete_entries(self, uid):
391 def delete_entries(self, uid):
392 if self.repo_settings:
392 if self.repo_settings:
393 all_patterns = self.get_repo_settings()
393 all_patterns = self.get_repo_settings()
394 settings_model = self.repo_settings
394 settings_model = self.repo_settings
395 else:
395 else:
396 all_patterns = self.get_global_settings()
396 all_patterns = self.get_global_settings()
397 settings_model = self.global_settings
397 settings_model = self.global_settings
398 entries = all_patterns.get(uid, [])
398 entries = all_patterns.get(uid, [])
399
399
400 for del_key in entries:
400 for del_key in entries:
401 setting_name = self._get_keyname(del_key, uid)
401 setting_name = self._get_keyname(del_key, uid)
402 entry = settings_model.get_setting_by_name(setting_name)
402 entry = settings_model.get_setting_by_name(setting_name)
403 if entry:
403 if entry:
404 Session().delete(entry)
404 Session().delete(entry)
405
405
406 Session().commit()
406 Session().commit()
407
407
408 def create_or_update_setting(
408 def create_or_update_setting(
409 self, name, val=Optional(''), type_=Optional('unicode')):
409 self, name, val=Optional(''), type_=Optional('unicode')):
410 if self.repo_settings:
410 if self.repo_settings:
411 setting = self.repo_settings.create_or_update_setting(
411 setting = self.repo_settings.create_or_update_setting(
412 name, val, type_)
412 name, val, type_)
413 else:
413 else:
414 setting = self.global_settings.create_or_update_setting(
414 setting = self.global_settings.create_or_update_setting(
415 name, val, type_)
415 name, val, type_)
416 return setting
416 return setting
417
417
418
418
419 class VcsSettingsModel(object):
419 class VcsSettingsModel(object):
420
420
421 INHERIT_SETTINGS = 'inherit_vcs_settings'
421 INHERIT_SETTINGS = 'inherit_vcs_settings'
422 GENERAL_SETTINGS = (
422 GENERAL_SETTINGS = (
423 'use_outdated_comments',
423 'use_outdated_comments',
424 'pr_merge_enabled',
424 'pr_merge_enabled',
425 'hg_use_rebase_for_merging',
425 'hg_use_rebase_for_merging',
426 'hg_close_branch_before_merging',
426 'hg_close_branch_before_merging',
427 'git_use_rebase_for_merging',
427 'git_use_rebase_for_merging',
428 'git_close_branch_before_merging',
428 'git_close_branch_before_merging',
429 'diff_cache',
429 'diff_cache',
430 )
430 )
431
431
432 HOOKS_SETTINGS = (
432 HOOKS_SETTINGS = (
433 ('hooks', 'changegroup.repo_size'),
433 ('hooks', 'changegroup.repo_size'),
434 ('hooks', 'changegroup.push_logger'),
434 ('hooks', 'changegroup.push_logger'),
435 ('hooks', 'outgoing.pull_logger'),)
435 ('hooks', 'outgoing.pull_logger'),)
436 HG_SETTINGS = (
436 HG_SETTINGS = (
437 ('extensions', 'largefiles'),
437 ('extensions', 'largefiles'),
438 ('phases', 'publish'),
438 ('phases', 'publish'),
439 ('extensions', 'evolve'),)
439 ('extensions', 'evolve'),)
440 GIT_SETTINGS = (
440 GIT_SETTINGS = (
441 ('vcs_git_lfs', 'enabled'),)
441 ('vcs_git_lfs', 'enabled'),)
442 GLOBAL_HG_SETTINGS = (
442 GLOBAL_HG_SETTINGS = (
443 ('extensions', 'largefiles'),
443 ('extensions', 'largefiles'),
444 ('largefiles', 'usercache'),
444 ('largefiles', 'usercache'),
445 ('phases', 'publish'),
445 ('phases', 'publish'),
446 ('extensions', 'hgsubversion'),
446 ('extensions', 'hgsubversion'),
447 ('extensions', 'evolve'),)
447 ('extensions', 'evolve'),)
448 GLOBAL_GIT_SETTINGS = (
448 GLOBAL_GIT_SETTINGS = (
449 ('vcs_git_lfs', 'enabled'),
449 ('vcs_git_lfs', 'enabled'),
450 ('vcs_git_lfs', 'store_location'))
450 ('vcs_git_lfs', 'store_location'))
451
451
452 GLOBAL_SVN_SETTINGS = (
452 GLOBAL_SVN_SETTINGS = (
453 ('vcs_svn_proxy', 'http_requests_enabled'),
453 ('vcs_svn_proxy', 'http_requests_enabled'),
454 ('vcs_svn_proxy', 'http_server_url'))
454 ('vcs_svn_proxy', 'http_server_url'))
455
455
456 SVN_BRANCH_SECTION = 'vcs_svn_branch'
456 SVN_BRANCH_SECTION = 'vcs_svn_branch'
457 SVN_TAG_SECTION = 'vcs_svn_tag'
457 SVN_TAG_SECTION = 'vcs_svn_tag'
458 SSL_SETTING = ('web', 'push_ssl')
458 SSL_SETTING = ('web', 'push_ssl')
459 PATH_SETTING = ('paths', '/')
459 PATH_SETTING = ('paths', '/')
460
460
461 def __init__(self, sa=None, repo=None):
461 def __init__(self, sa=None, repo=None):
462 self.global_settings = SettingsModel(sa=sa)
462 self.global_settings = SettingsModel(sa=sa)
463 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
463 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
464 self._ui_settings = (
464 self._ui_settings = (
465 self.HG_SETTINGS + self.GIT_SETTINGS + self.HOOKS_SETTINGS)
465 self.HG_SETTINGS + self.GIT_SETTINGS + self.HOOKS_SETTINGS)
466 self._svn_sections = (self.SVN_BRANCH_SECTION, self.SVN_TAG_SECTION)
466 self._svn_sections = (self.SVN_BRANCH_SECTION, self.SVN_TAG_SECTION)
467
467
468 @property
468 @property
469 @assert_repo_settings
469 @assert_repo_settings
470 def inherit_global_settings(self):
470 def inherit_global_settings(self):
471 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
471 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
472 return setting.app_settings_value if setting else True
472 return setting.app_settings_value if setting else True
473
473
474 @inherit_global_settings.setter
474 @inherit_global_settings.setter
475 @assert_repo_settings
475 @assert_repo_settings
476 def inherit_global_settings(self, value):
476 def inherit_global_settings(self, value):
477 self.repo_settings.create_or_update_setting(
477 self.repo_settings.create_or_update_setting(
478 self.INHERIT_SETTINGS, value, type_='bool')
478 self.INHERIT_SETTINGS, value, type_='bool')
479
479
480 def get_global_svn_branch_patterns(self):
480 def get_global_svn_branch_patterns(self):
481 return self.global_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
481 return self.global_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
482
482
483 @assert_repo_settings
483 @assert_repo_settings
484 def get_repo_svn_branch_patterns(self):
484 def get_repo_svn_branch_patterns(self):
485 return self.repo_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
485 return self.repo_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
486
486
487 def get_global_svn_tag_patterns(self):
487 def get_global_svn_tag_patterns(self):
488 return self.global_settings.get_ui_by_section(self.SVN_TAG_SECTION)
488 return self.global_settings.get_ui_by_section(self.SVN_TAG_SECTION)
489
489
490 @assert_repo_settings
490 @assert_repo_settings
491 def get_repo_svn_tag_patterns(self):
491 def get_repo_svn_tag_patterns(self):
492 return self.repo_settings.get_ui_by_section(self.SVN_TAG_SECTION)
492 return self.repo_settings.get_ui_by_section(self.SVN_TAG_SECTION)
493
493
494 def get_global_settings(self):
494 def get_global_settings(self):
495 return self._collect_all_settings(global_=True)
495 return self._collect_all_settings(global_=True)
496
496
497 @assert_repo_settings
497 @assert_repo_settings
498 def get_repo_settings(self):
498 def get_repo_settings(self):
499 return self._collect_all_settings(global_=False)
499 return self._collect_all_settings(global_=False)
500
500
501 @assert_repo_settings
501 @assert_repo_settings
502 def create_or_update_repo_settings(
502 def create_or_update_repo_settings(
503 self, data, inherit_global_settings=False):
503 self, data, inherit_global_settings=False):
504 from rhodecode.model.scm import ScmModel
504 from rhodecode.model.scm import ScmModel
505
505
506 self.inherit_global_settings = inherit_global_settings
506 self.inherit_global_settings = inherit_global_settings
507
507
508 repo = self.repo_settings.get_repo()
508 repo = self.repo_settings.get_repo()
509 if not inherit_global_settings:
509 if not inherit_global_settings:
510 if repo.repo_type == 'svn':
510 if repo.repo_type == 'svn':
511 self.create_repo_svn_settings(data)
511 self.create_repo_svn_settings(data)
512 else:
512 else:
513 self.create_or_update_repo_hook_settings(data)
513 self.create_or_update_repo_hook_settings(data)
514 self.create_or_update_repo_pr_settings(data)
514 self.create_or_update_repo_pr_settings(data)
515
515
516 if repo.repo_type == 'hg':
516 if repo.repo_type == 'hg':
517 self.create_or_update_repo_hg_settings(data)
517 self.create_or_update_repo_hg_settings(data)
518
518
519 if repo.repo_type == 'git':
519 if repo.repo_type == 'git':
520 self.create_or_update_repo_git_settings(data)
520 self.create_or_update_repo_git_settings(data)
521
521
522 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
522 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
523
523
524 @assert_repo_settings
524 @assert_repo_settings
525 def create_or_update_repo_hook_settings(self, data):
525 def create_or_update_repo_hook_settings(self, data):
526 for section, key in self.HOOKS_SETTINGS:
526 for section, key in self.HOOKS_SETTINGS:
527 data_key = self._get_form_ui_key(section, key)
527 data_key = self._get_form_ui_key(section, key)
528 if data_key not in data:
528 if data_key not in data:
529 raise ValueError(
529 raise ValueError(
530 'The given data does not contain {} key'.format(data_key))
530 'The given data does not contain {} key'.format(data_key))
531
531
532 active = data.get(data_key)
532 active = data.get(data_key)
533 repo_setting = self.repo_settings.get_ui_by_section_and_key(
533 repo_setting = self.repo_settings.get_ui_by_section_and_key(
534 section, key)
534 section, key)
535 if not repo_setting:
535 if not repo_setting:
536 global_setting = self.global_settings.\
536 global_setting = self.global_settings.\
537 get_ui_by_section_and_key(section, key)
537 get_ui_by_section_and_key(section, key)
538 self.repo_settings.create_ui_section_value(
538 self.repo_settings.create_ui_section_value(
539 section, global_setting.ui_value, key=key, active=active)
539 section, global_setting.ui_value, key=key, active=active)
540 else:
540 else:
541 repo_setting.ui_active = active
541 repo_setting.ui_active = active
542 Session().add(repo_setting)
542 Session().add(repo_setting)
543
543
544 def update_global_hook_settings(self, data):
544 def update_global_hook_settings(self, data):
545 for section, key in self.HOOKS_SETTINGS:
545 for section, key in self.HOOKS_SETTINGS:
546 data_key = self._get_form_ui_key(section, key)
546 data_key = self._get_form_ui_key(section, key)
547 if data_key not in data:
547 if data_key not in data:
548 raise ValueError(
548 raise ValueError(
549 'The given data does not contain {} key'.format(data_key))
549 'The given data does not contain {} key'.format(data_key))
550 active = data.get(data_key)
550 active = data.get(data_key)
551 repo_setting = self.global_settings.get_ui_by_section_and_key(
551 repo_setting = self.global_settings.get_ui_by_section_and_key(
552 section, key)
552 section, key)
553 repo_setting.ui_active = active
553 repo_setting.ui_active = active
554 Session().add(repo_setting)
554 Session().add(repo_setting)
555
555
556 @assert_repo_settings
556 @assert_repo_settings
557 def create_or_update_repo_pr_settings(self, data):
557 def create_or_update_repo_pr_settings(self, data):
558 return self._create_or_update_general_settings(
558 return self._create_or_update_general_settings(
559 self.repo_settings, data)
559 self.repo_settings, data)
560
560
561 def create_or_update_global_pr_settings(self, data):
561 def create_or_update_global_pr_settings(self, data):
562 return self._create_or_update_general_settings(
562 return self._create_or_update_general_settings(
563 self.global_settings, data)
563 self.global_settings, data)
564
564
565 @assert_repo_settings
565 @assert_repo_settings
566 def create_repo_svn_settings(self, data):
566 def create_repo_svn_settings(self, data):
567 return self._create_svn_settings(self.repo_settings, data)
567 return self._create_svn_settings(self.repo_settings, data)
568
568
569 @assert_repo_settings
569 @assert_repo_settings
570 def create_or_update_repo_hg_settings(self, data):
570 def create_or_update_repo_hg_settings(self, data):
571 largefiles, phases, evolve = \
571 largefiles, phases, evolve = \
572 self.HG_SETTINGS
572 self.HG_SETTINGS
573 largefiles_key, phases_key, evolve_key = \
573 largefiles_key, phases_key, evolve_key = \
574 self._get_settings_keys(self.HG_SETTINGS, data)
574 self._get_settings_keys(self.HG_SETTINGS, data)
575
575
576 self._create_or_update_ui(
576 self._create_or_update_ui(
577 self.repo_settings, *largefiles, value='',
577 self.repo_settings, *largefiles, value='',
578 active=data[largefiles_key])
578 active=data[largefiles_key])
579 self._create_or_update_ui(
579 self._create_or_update_ui(
580 self.repo_settings, *evolve, value='',
580 self.repo_settings, *evolve, value='',
581 active=data[evolve_key])
581 active=data[evolve_key])
582 self._create_or_update_ui(
582 self._create_or_update_ui(
583 self.repo_settings, *phases, value=safe_str(data[phases_key]))
583 self.repo_settings, *phases, value=safe_str(data[phases_key]))
584
584
585
586 def create_or_update_global_hg_settings(self, data):
585 def create_or_update_global_hg_settings(self, data):
587 largefiles, largefiles_store, phases, hgsubversion, evolve \
586 largefiles, largefiles_store, phases, hgsubversion, evolve \
588 = self.GLOBAL_HG_SETTINGS
587 = self.GLOBAL_HG_SETTINGS
589 largefiles_key, largefiles_store_key, phases_key, subversion_key, evolve_key \
588 largefiles_key, largefiles_store_key, phases_key, subversion_key, evolve_key \
590 = self._get_settings_keys(self.GLOBAL_HG_SETTINGS, data)
589 = self._get_settings_keys(self.GLOBAL_HG_SETTINGS, data)
591
590
592 self._create_or_update_ui(
591 self._create_or_update_ui(
593 self.global_settings, *largefiles, value='',
592 self.global_settings, *largefiles, value='',
594 active=data[largefiles_key])
593 active=data[largefiles_key])
595 self._create_or_update_ui(
594 self._create_or_update_ui(
596 self.global_settings, *largefiles_store,
595 self.global_settings, *largefiles_store,
597 value=data[largefiles_store_key])
596 value=data[largefiles_store_key])
598 self._create_or_update_ui(
597 self._create_or_update_ui(
599 self.global_settings, *phases, value=safe_str(data[phases_key]))
598 self.global_settings, *phases, value=safe_str(data[phases_key]))
600 self._create_or_update_ui(
599 self._create_or_update_ui(
601 self.global_settings, *hgsubversion, active=data[subversion_key])
600 self.global_settings, *hgsubversion, active=data[subversion_key])
602 self._create_or_update_ui(
601 self._create_or_update_ui(
603 self.global_settings, *evolve, value='',
602 self.global_settings, *evolve, value='',
604 active=data[evolve_key])
603 active=data[evolve_key])
605
604
606 def create_or_update_repo_git_settings(self, data):
605 def create_or_update_repo_git_settings(self, data):
607 # NOTE(marcink): # comma make unpack work properly
606 # NOTE(marcink): # comma make unpack work properly
608 lfs_enabled, \
607 lfs_enabled, \
609 = self.GIT_SETTINGS
608 = self.GIT_SETTINGS
610
609
611 lfs_enabled_key, \
610 lfs_enabled_key, \
612 = self._get_settings_keys(self.GIT_SETTINGS, data)
611 = self._get_settings_keys(self.GIT_SETTINGS, data)
613
612
614 self._create_or_update_ui(
613 self._create_or_update_ui(
615 self.repo_settings, *lfs_enabled, value=data[lfs_enabled_key],
614 self.repo_settings, *lfs_enabled, value=data[lfs_enabled_key],
616 active=data[lfs_enabled_key])
615 active=data[lfs_enabled_key])
617
616
618 def create_or_update_global_git_settings(self, data):
617 def create_or_update_global_git_settings(self, data):
619 lfs_enabled, lfs_store_location \
618 lfs_enabled, lfs_store_location \
620 = self.GLOBAL_GIT_SETTINGS
619 = self.GLOBAL_GIT_SETTINGS
621 lfs_enabled_key, lfs_store_location_key \
620 lfs_enabled_key, lfs_store_location_key \
622 = self._get_settings_keys(self.GLOBAL_GIT_SETTINGS, data)
621 = self._get_settings_keys(self.GLOBAL_GIT_SETTINGS, data)
623
622
624 self._create_or_update_ui(
623 self._create_or_update_ui(
625 self.global_settings, *lfs_enabled, value=data[lfs_enabled_key],
624 self.global_settings, *lfs_enabled, value=data[lfs_enabled_key],
626 active=data[lfs_enabled_key])
625 active=data[lfs_enabled_key])
627 self._create_or_update_ui(
626 self._create_or_update_ui(
628 self.global_settings, *lfs_store_location,
627 self.global_settings, *lfs_store_location,
629 value=data[lfs_store_location_key])
628 value=data[lfs_store_location_key])
630
629
631 def create_or_update_global_svn_settings(self, data):
630 def create_or_update_global_svn_settings(self, data):
632 # branch/tags patterns
631 # branch/tags patterns
633 self._create_svn_settings(self.global_settings, data)
632 self._create_svn_settings(self.global_settings, data)
634
633
635 http_requests_enabled, http_server_url = self.GLOBAL_SVN_SETTINGS
634 http_requests_enabled, http_server_url = self.GLOBAL_SVN_SETTINGS
636 http_requests_enabled_key, http_server_url_key = self._get_settings_keys(
635 http_requests_enabled_key, http_server_url_key = self._get_settings_keys(
637 self.GLOBAL_SVN_SETTINGS, data)
636 self.GLOBAL_SVN_SETTINGS, data)
638
637
639 self._create_or_update_ui(
638 self._create_or_update_ui(
640 self.global_settings, *http_requests_enabled,
639 self.global_settings, *http_requests_enabled,
641 value=safe_str(data[http_requests_enabled_key]))
640 value=safe_str(data[http_requests_enabled_key]))
642 self._create_or_update_ui(
641 self._create_or_update_ui(
643 self.global_settings, *http_server_url,
642 self.global_settings, *http_server_url,
644 value=data[http_server_url_key])
643 value=data[http_server_url_key])
645
644
646 def update_global_ssl_setting(self, value):
645 def update_global_ssl_setting(self, value):
647 self._create_or_update_ui(
646 self._create_or_update_ui(
648 self.global_settings, *self.SSL_SETTING, value=value)
647 self.global_settings, *self.SSL_SETTING, value=value)
649
648
650 def update_global_path_setting(self, value):
649 def update_global_path_setting(self, value):
651 self._create_or_update_ui(
650 self._create_or_update_ui(
652 self.global_settings, *self.PATH_SETTING, value=value)
651 self.global_settings, *self.PATH_SETTING, value=value)
653
652
654 @assert_repo_settings
653 @assert_repo_settings
655 def delete_repo_svn_pattern(self, id_):
654 def delete_repo_svn_pattern(self, id_):
656 ui = self.repo_settings.UiDbModel.get(id_)
655 ui = self.repo_settings.UiDbModel.get(id_)
657 if ui and ui.repository.repo_name == self.repo_settings.repo:
656 if ui and ui.repository.repo_name == self.repo_settings.repo:
658 # only delete if it's the same repo as initialized settings
657 # only delete if it's the same repo as initialized settings
659 self.repo_settings.delete_ui(id_)
658 self.repo_settings.delete_ui(id_)
660 else:
659 else:
661 # raise error as if we wouldn't find this option
660 # raise error as if we wouldn't find this option
662 self.repo_settings.delete_ui(-1)
661 self.repo_settings.delete_ui(-1)
663
662
664 def delete_global_svn_pattern(self, id_):
663 def delete_global_svn_pattern(self, id_):
665 self.global_settings.delete_ui(id_)
664 self.global_settings.delete_ui(id_)
666
665
667 @assert_repo_settings
666 @assert_repo_settings
668 def get_repo_ui_settings(self, section=None, key=None):
667 def get_repo_ui_settings(self, section=None, key=None):
669 global_uis = self.global_settings.get_ui(section, key)
668 global_uis = self.global_settings.get_ui(section, key)
670 repo_uis = self.repo_settings.get_ui(section, key)
669 repo_uis = self.repo_settings.get_ui(section, key)
671 filtered_repo_uis = self._filter_ui_settings(repo_uis)
670 filtered_repo_uis = self._filter_ui_settings(repo_uis)
672 filtered_repo_uis_keys = [
671 filtered_repo_uis_keys = [
673 (s.section, s.key) for s in filtered_repo_uis]
672 (s.section, s.key) for s in filtered_repo_uis]
674
673
675 def _is_global_ui_filtered(ui):
674 def _is_global_ui_filtered(ui):
676 return (
675 return (
677 (ui.section, ui.key) in filtered_repo_uis_keys
676 (ui.section, ui.key) in filtered_repo_uis_keys
678 or ui.section in self._svn_sections)
677 or ui.section in self._svn_sections)
679
678
680 filtered_global_uis = [
679 filtered_global_uis = [
681 ui for ui in global_uis if not _is_global_ui_filtered(ui)]
680 ui for ui in global_uis if not _is_global_ui_filtered(ui)]
682
681
683 return filtered_global_uis + filtered_repo_uis
682 return filtered_global_uis + filtered_repo_uis
684
683
685 def get_global_ui_settings(self, section=None, key=None):
684 def get_global_ui_settings(self, section=None, key=None):
686 return self.global_settings.get_ui(section, key)
685 return self.global_settings.get_ui(section, key)
687
686
688 def get_ui_settings_as_config_obj(self, section=None, key=None):
687 def get_ui_settings_as_config_obj(self, section=None, key=None):
689 config = base.Config()
688 config = base.Config()
690
689
691 ui_settings = self.get_ui_settings(section=section, key=key)
690 ui_settings = self.get_ui_settings(section=section, key=key)
692
691
693 for entry in ui_settings:
692 for entry in ui_settings:
694 config.set(entry.section, entry.key, entry.value)
693 config.set(entry.section, entry.key, entry.value)
695
694
696 return config
695 return config
697
696
698 def get_ui_settings(self, section=None, key=None):
697 def get_ui_settings(self, section=None, key=None):
699 if not self.repo_settings or self.inherit_global_settings:
698 if not self.repo_settings or self.inherit_global_settings:
700 return self.get_global_ui_settings(section, key)
699 return self.get_global_ui_settings(section, key)
701 else:
700 else:
702 return self.get_repo_ui_settings(section, key)
701 return self.get_repo_ui_settings(section, key)
703
702
704 def get_svn_patterns(self, section=None):
703 def get_svn_patterns(self, section=None):
705 if not self.repo_settings:
704 if not self.repo_settings:
706 return self.get_global_ui_settings(section)
705 return self.get_global_ui_settings(section)
707 else:
706 else:
708 return self.get_repo_ui_settings(section)
707 return self.get_repo_ui_settings(section)
709
708
710 @assert_repo_settings
709 @assert_repo_settings
711 def get_repo_general_settings(self):
710 def get_repo_general_settings(self):
712 global_settings = self.global_settings.get_all_settings()
711 global_settings = self.global_settings.get_all_settings()
713 repo_settings = self.repo_settings.get_all_settings()
712 repo_settings = self.repo_settings.get_all_settings()
714 filtered_repo_settings = self._filter_general_settings(repo_settings)
713 filtered_repo_settings = self._filter_general_settings(repo_settings)
715 global_settings.update(filtered_repo_settings)
714 global_settings.update(filtered_repo_settings)
716 return global_settings
715 return global_settings
717
716
718 def get_global_general_settings(self):
717 def get_global_general_settings(self):
719 return self.global_settings.get_all_settings()
718 return self.global_settings.get_all_settings()
720
719
721 def get_general_settings(self):
720 def get_general_settings(self):
722 if not self.repo_settings or self.inherit_global_settings:
721 if not self.repo_settings or self.inherit_global_settings:
723 return self.get_global_general_settings()
722 return self.get_global_general_settings()
724 else:
723 else:
725 return self.get_repo_general_settings()
724 return self.get_repo_general_settings()
726
725
727 def get_repos_location(self):
726 def get_repos_location(self):
728 return self.global_settings.get_ui_by_key('/').ui_value
727 return self.global_settings.get_ui_by_key('/').ui_value
729
728
730 def _filter_ui_settings(self, settings):
729 def _filter_ui_settings(self, settings):
731 filtered_settings = [
730 filtered_settings = [
732 s for s in settings if self._should_keep_setting(s)]
731 s for s in settings if self._should_keep_setting(s)]
733 return filtered_settings
732 return filtered_settings
734
733
735 def _should_keep_setting(self, setting):
734 def _should_keep_setting(self, setting):
736 keep = (
735 keep = (
737 (setting.section, setting.key) in self._ui_settings or
736 (setting.section, setting.key) in self._ui_settings or
738 setting.section in self._svn_sections)
737 setting.section in self._svn_sections)
739 return keep
738 return keep
740
739
741 def _filter_general_settings(self, settings):
740 def _filter_general_settings(self, settings):
742 keys = ['rhodecode_{}'.format(key) for key in self.GENERAL_SETTINGS]
741 keys = ['rhodecode_{}'.format(key) for key in self.GENERAL_SETTINGS]
743 return {
742 return {
744 k: settings[k]
743 k: settings[k]
745 for k in settings if k in keys}
744 for k in settings if k in keys}
746
745
747 def _collect_all_settings(self, global_=False):
746 def _collect_all_settings(self, global_=False):
748 settings = self.global_settings if global_ else self.repo_settings
747 settings = self.global_settings if global_ else self.repo_settings
749 result = {}
748 result = {}
750
749
751 for section, key in self._ui_settings:
750 for section, key in self._ui_settings:
752 ui = settings.get_ui_by_section_and_key(section, key)
751 ui = settings.get_ui_by_section_and_key(section, key)
753 result_key = self._get_form_ui_key(section, key)
752 result_key = self._get_form_ui_key(section, key)
754
753
755 if ui:
754 if ui:
756 if section in ('hooks', 'extensions'):
755 if section in ('hooks', 'extensions'):
757 result[result_key] = ui.ui_active
756 result[result_key] = ui.ui_active
758 elif result_key in ['vcs_git_lfs_enabled']:
757 elif result_key in ['vcs_git_lfs_enabled']:
759 result[result_key] = ui.ui_active
758 result[result_key] = ui.ui_active
760 else:
759 else:
761 result[result_key] = ui.ui_value
760 result[result_key] = ui.ui_value
762
761
763 for name in self.GENERAL_SETTINGS:
762 for name in self.GENERAL_SETTINGS:
764 setting = settings.get_setting_by_name(name)
763 setting = settings.get_setting_by_name(name)
765 if setting:
764 if setting:
766 result_key = 'rhodecode_{}'.format(name)
765 result_key = 'rhodecode_{}'.format(name)
767 result[result_key] = setting.app_settings_value
766 result[result_key] = setting.app_settings_value
768
767
769 return result
768 return result
770
769
771 def _get_form_ui_key(self, section, key):
770 def _get_form_ui_key(self, section, key):
772 return '{section}_{key}'.format(
771 return '{section}_{key}'.format(
773 section=section, key=key.replace('.', '_'))
772 section=section, key=key.replace('.', '_'))
774
773
775 def _create_or_update_ui(
774 def _create_or_update_ui(
776 self, settings, section, key, value=None, active=None):
775 self, settings, section, key, value=None, active=None):
777 ui = settings.get_ui_by_section_and_key(section, key)
776 ui = settings.get_ui_by_section_and_key(section, key)
778 if not ui:
777 if not ui:
779 active = True if active is None else active
778 active = True if active is None else active
780 settings.create_ui_section_value(
779 settings.create_ui_section_value(
781 section, value, key=key, active=active)
780 section, value, key=key, active=active)
782 else:
781 else:
783 if active is not None:
782 if active is not None:
784 ui.ui_active = active
783 ui.ui_active = active
785 if value is not None:
784 if value is not None:
786 ui.ui_value = value
785 ui.ui_value = value
787 Session().add(ui)
786 Session().add(ui)
788
787
789 def _create_svn_settings(self, settings, data):
788 def _create_svn_settings(self, settings, data):
790 svn_settings = {
789 svn_settings = {
791 'new_svn_branch': self.SVN_BRANCH_SECTION,
790 'new_svn_branch': self.SVN_BRANCH_SECTION,
792 'new_svn_tag': self.SVN_TAG_SECTION
791 'new_svn_tag': self.SVN_TAG_SECTION
793 }
792 }
794 for key in svn_settings:
793 for key in svn_settings:
795 if data.get(key):
794 if data.get(key):
796 settings.create_ui_section_value(svn_settings[key], data[key])
795 settings.create_ui_section_value(svn_settings[key], data[key])
797
796
798 def _create_or_update_general_settings(self, settings, data):
797 def _create_or_update_general_settings(self, settings, data):
799 for name in self.GENERAL_SETTINGS:
798 for name in self.GENERAL_SETTINGS:
800 data_key = 'rhodecode_{}'.format(name)
799 data_key = 'rhodecode_{}'.format(name)
801 if data_key not in data:
800 if data_key not in data:
802 raise ValueError(
801 raise ValueError(
803 'The given data does not contain {} key'.format(data_key))
802 'The given data does not contain {} key'.format(data_key))
804 setting = settings.create_or_update_setting(
803 setting = settings.create_or_update_setting(
805 name, data[data_key], 'bool')
804 name, data[data_key], 'bool')
806 Session().add(setting)
805 Session().add(setting)
807
806
808 def _get_settings_keys(self, settings, data):
807 def _get_settings_keys(self, settings, data):
809 data_keys = [self._get_form_ui_key(*s) for s in settings]
808 data_keys = [self._get_form_ui_key(*s) for s in settings]
810 for data_key in data_keys:
809 for data_key in data_keys:
811 if data_key not in data:
810 if data_key not in data:
812 raise ValueError(
811 raise ValueError(
813 'The given data does not contain {} key'.format(data_key))
812 'The given data does not contain {} key'.format(data_key))
814 return data_keys
813 return data_keys
815
814
816 def create_largeobjects_dirs_if_needed(self, repo_store_path):
815 def create_largeobjects_dirs_if_needed(self, repo_store_path):
817 """
816 """
818 This is subscribed to the `pyramid.events.ApplicationCreated` event. It
817 This is subscribed to the `pyramid.events.ApplicationCreated` event. It
819 does a repository scan if enabled in the settings.
818 does a repository scan if enabled in the settings.
820 """
819 """
821
820
822 from rhodecode.lib.vcs.backends.hg import largefiles_store
821 from rhodecode.lib.vcs.backends.hg import largefiles_store
823 from rhodecode.lib.vcs.backends.git import lfs_store
822 from rhodecode.lib.vcs.backends.git import lfs_store
824
823
825 paths = [
824 paths = [
826 largefiles_store(repo_store_path),
825 largefiles_store(repo_store_path),
827 lfs_store(repo_store_path)]
826 lfs_store(repo_store_path)]
828
827
829 for path in paths:
828 for path in paths:
830 if os.path.isdir(path):
829 if os.path.isdir(path):
831 continue
830 continue
832 if os.path.isfile(path):
831 if os.path.isfile(path):
833 continue
832 continue
834 # not a file nor dir, we try to create it
833 # not a file nor dir, we try to create it
835 try:
834 try:
836 os.makedirs(path)
835 os.makedirs(path)
837 except Exception:
836 except Exception:
838 log.warning('Failed to create largefiles dir:%s', path)
837 log.warning('Failed to create largefiles dir:%s', path)
@@ -1,553 +1,553 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2018 RhodeCode GmbH
3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import datetime
21 import datetime
22 from urllib2 import URLError
22 from urllib2 import URLError
23
23
24 import mock
24 import mock
25 import pytest
25 import pytest
26
26
27 from rhodecode.lib.vcs import backends
27 from rhodecode.lib.vcs import backends
28 from rhodecode.lib.vcs.backends.base import (
28 from rhodecode.lib.vcs.backends.base import (
29 Config, BaseInMemoryCommit, Reference, MergeResponse, MergeFailureReason)
29 Config, BaseInMemoryCommit, Reference, MergeResponse, MergeFailureReason)
30 from rhodecode.lib.vcs.exceptions import VCSError, RepositoryError
30 from rhodecode.lib.vcs.exceptions import VCSError, RepositoryError
31 from rhodecode.lib.vcs.nodes import FileNode
31 from rhodecode.lib.vcs.nodes import FileNode
32 from rhodecode.tests.vcs.conftest import BackendTestMixin
32 from rhodecode.tests.vcs.conftest import BackendTestMixin
33 from rhodecode.tests import repo_id_generator
33 from rhodecode.tests import repo_id_generator
34
34
35
35
36 @pytest.mark.usefixtures("vcs_repository_support")
36 @pytest.mark.usefixtures("vcs_repository_support")
37 class TestRepositoryBase(BackendTestMixin):
37 class TestRepositoryBase(BackendTestMixin):
38 recreate_repo_per_test = False
38 recreate_repo_per_test = False
39
39
40 def test_init_accepts_unicode_path(self, tmpdir):
40 def test_init_accepts_unicode_path(self, tmpdir):
41 path = unicode(tmpdir.join(u'unicode Γ€'))
41 path = unicode(tmpdir.join(u'unicode Γ€'))
42 self.Backend(path, create=True)
42 self.Backend(path, create=True)
43
43
44 def test_init_accepts_str_path(self, tmpdir):
44 def test_init_accepts_str_path(self, tmpdir):
45 path = str(tmpdir.join('str Γ€'))
45 path = str(tmpdir.join('str Γ€'))
46 self.Backend(path, create=True)
46 self.Backend(path, create=True)
47
47
48 def test_init_fails_if_path_does_not_exist(self, tmpdir):
48 def test_init_fails_if_path_does_not_exist(self, tmpdir):
49 path = unicode(tmpdir.join('i-do-not-exist'))
49 path = unicode(tmpdir.join('i-do-not-exist'))
50 with pytest.raises(VCSError):
50 with pytest.raises(VCSError):
51 self.Backend(path)
51 self.Backend(path)
52
52
53 def test_init_fails_if_path_is_not_a_valid_repository(self, tmpdir):
53 def test_init_fails_if_path_is_not_a_valid_repository(self, tmpdir):
54 path = unicode(tmpdir.mkdir(u'unicode Γ€'))
54 path = unicode(tmpdir.mkdir(u'unicode Γ€'))
55 with pytest.raises(VCSError):
55 with pytest.raises(VCSError):
56 self.Backend(path)
56 self.Backend(path)
57
57
58 def test_has_commits_attribute(self):
58 def test_has_commits_attribute(self):
59 self.repo.commit_ids
59 self.repo.commit_ids
60
60
61 def test_name(self):
61 def test_name(self):
62 assert self.repo.name.startswith('vcs-test')
62 assert self.repo.name.startswith('vcs-test')
63
63
64 @pytest.mark.backends("hg", "git")
64 @pytest.mark.backends("hg", "git")
65 def test_has_default_branch_name(self):
65 def test_has_default_branch_name(self):
66 assert self.repo.DEFAULT_BRANCH_NAME is not None
66 assert self.repo.DEFAULT_BRANCH_NAME is not None
67
67
68 @pytest.mark.backends("svn")
68 @pytest.mark.backends("svn")
69 def test_has_no_default_branch_name(self):
69 def test_has_no_default_branch_name(self):
70 assert self.repo.DEFAULT_BRANCH_NAME is None
70 assert self.repo.DEFAULT_BRANCH_NAME is None
71
71
72 def test_has_empty_commit(self):
72 def test_has_empty_commit(self):
73 assert self.repo.EMPTY_COMMIT_ID is not None
73 assert self.repo.EMPTY_COMMIT_ID is not None
74 assert self.repo.EMPTY_COMMIT is not None
74 assert self.repo.EMPTY_COMMIT is not None
75
75
76 def test_empty_changeset_is_deprecated(self):
76 def test_empty_changeset_is_deprecated(self):
77 def get_empty_changeset(repo):
77 def get_empty_changeset(repo):
78 return repo.EMPTY_CHANGESET
78 return repo.EMPTY_CHANGESET
79 pytest.deprecated_call(get_empty_changeset, self.repo)
79 pytest.deprecated_call(get_empty_changeset, self.repo)
80
80
81 def test_bookmarks(self):
81 def test_bookmarks(self):
82 assert len(self.repo.bookmarks) == 0
82 assert len(self.repo.bookmarks) == 0
83
83
84 # TODO: Cover two cases: Local repo path, remote URL
84 # TODO: Cover two cases: Local repo path, remote URL
85 def test_check_url(self):
85 def test_check_url(self):
86 config = Config()
86 config = Config()
87 assert self.Backend.check_url(self.repo.path, config)
87 assert self.Backend.check_url(self.repo.path, config)
88
88
89 def test_check_url_invalid(self):
89 def test_check_url_invalid(self):
90 config = Config()
90 config = Config()
91 with pytest.raises(URLError):
91 with pytest.raises(URLError):
92 self.Backend.check_url(self.repo.path + "invalid", config)
92 self.Backend.check_url(self.repo.path + "invalid", config)
93
93
94 def test_get_contact(self):
94 def test_get_contact(self):
95 assert self.repo.contact
95 assert self.repo.contact
96
96
97 def test_get_description(self):
97 def test_get_description(self):
98 assert self.repo.description
98 assert self.repo.description
99
99
100 def test_get_hook_location(self):
100 def test_get_hook_location(self):
101 assert len(self.repo.get_hook_location()) != 0
101 assert len(self.repo.get_hook_location()) != 0
102
102
103 def test_last_change(self, local_dt_to_utc):
103 def test_last_change(self, local_dt_to_utc):
104 assert self.repo.last_change >= local_dt_to_utc(
104 assert self.repo.last_change >= local_dt_to_utc(
105 datetime.datetime(2010, 1, 1, 21, 0))
105 datetime.datetime(2010, 1, 1, 21, 0))
106
106
107 def test_last_change_in_empty_repository(self, vcsbackend, local_dt_to_utc):
107 def test_last_change_in_empty_repository(self, vcsbackend, local_dt_to_utc):
108 delta = datetime.timedelta(seconds=1)
108 delta = datetime.timedelta(seconds=1)
109
109
110 start = local_dt_to_utc(datetime.datetime.now())
110 start = local_dt_to_utc(datetime.datetime.now())
111 empty_repo = vcsbackend.create_repo()
111 empty_repo = vcsbackend.create_repo()
112 now = local_dt_to_utc(datetime.datetime.now())
112 now = local_dt_to_utc(datetime.datetime.now())
113 assert empty_repo.last_change >= start - delta
113 assert empty_repo.last_change >= start - delta
114 assert empty_repo.last_change <= now + delta
114 assert empty_repo.last_change <= now + delta
115
115
116 def test_repo_equality(self):
116 def test_repo_equality(self):
117 assert self.repo == self.repo
117 assert self.repo == self.repo
118
118
119 def test_repo_equality_broken_object(self):
119 def test_repo_equality_broken_object(self):
120 import copy
120 import copy
121 _repo = copy.copy(self.repo)
121 _repo = copy.copy(self.repo)
122 delattr(_repo, 'path')
122 delattr(_repo, 'path')
123 assert self.repo != _repo
123 assert self.repo != _repo
124
124
125 def test_repo_equality_other_object(self):
125 def test_repo_equality_other_object(self):
126 class dummy(object):
126 class dummy(object):
127 path = self.repo.path
127 path = self.repo.path
128 assert self.repo != dummy()
128 assert self.repo != dummy()
129
129
130 def test_get_commit_is_implemented(self):
130 def test_get_commit_is_implemented(self):
131 self.repo.get_commit()
131 self.repo.get_commit()
132
132
133 def test_get_commits_is_implemented(self):
133 def test_get_commits_is_implemented(self):
134 commit_iter = iter(self.repo.get_commits())
134 commit_iter = iter(self.repo.get_commits())
135 commit = next(commit_iter)
135 commit = next(commit_iter)
136 assert commit.idx == 0
136 assert commit.idx == 0
137
137
138 def test_supports_iteration(self):
138 def test_supports_iteration(self):
139 repo_iter = iter(self.repo)
139 repo_iter = iter(self.repo)
140 commit = next(repo_iter)
140 commit = next(repo_iter)
141 assert commit.idx == 0
141 assert commit.idx == 0
142
142
143 def test_in_memory_commit(self):
143 def test_in_memory_commit(self):
144 imc = self.repo.in_memory_commit
144 imc = self.repo.in_memory_commit
145 assert isinstance(imc, BaseInMemoryCommit)
145 assert isinstance(imc, BaseInMemoryCommit)
146
146
147 @pytest.mark.backends("hg")
147 @pytest.mark.backends("hg")
148 def test__get_url_unicode(self):
148 def test__get_url_unicode(self):
149 url = u'/home/repos/malmΓΆ'
149 url = u'/home/repos/malmΓΆ'
150 assert self.repo._get_url(url)
150 assert self.repo._get_url(url)
151
151
152
152
153 @pytest.mark.usefixtures("vcs_repository_support")
153 @pytest.mark.usefixtures("vcs_repository_support")
154 class TestDeprecatedRepositoryAPI(BackendTestMixin):
154 class TestDeprecatedRepositoryAPI(BackendTestMixin):
155 recreate_repo_per_test = False
155 recreate_repo_per_test = False
156
156
157 def test_revisions_is_deprecated(self):
157 def test_revisions_is_deprecated(self):
158 def get_revisions(repo):
158 def get_revisions(repo):
159 return repo.revisions
159 return repo.revisions
160 pytest.deprecated_call(get_revisions, self.repo)
160 pytest.deprecated_call(get_revisions, self.repo)
161
161
162 def test_get_changeset_is_deprecated(self):
162 def test_get_changeset_is_deprecated(self):
163 pytest.deprecated_call(self.repo.get_changeset)
163 pytest.deprecated_call(self.repo.get_changeset)
164
164
165 def test_get_changesets_is_deprecated(self):
165 def test_get_changesets_is_deprecated(self):
166 pytest.deprecated_call(self.repo.get_changesets)
166 pytest.deprecated_call(self.repo.get_changesets)
167
167
168 def test_in_memory_changeset_is_deprecated(self):
168 def test_in_memory_changeset_is_deprecated(self):
169 def get_imc(repo):
169 def get_imc(repo):
170 return repo.in_memory_changeset
170 return repo.in_memory_changeset
171 pytest.deprecated_call(get_imc, self.repo)
171 pytest.deprecated_call(get_imc, self.repo)
172
172
173
173
174 # TODO: these tests are incomplete, must check the resulting compare result for
174 # TODO: these tests are incomplete, must check the resulting compare result for
175 # correcteness
175 # correcteness
176 class TestRepositoryCompare:
176 class TestRepositoryCompare:
177
177
178 @pytest.mark.parametrize('merge', [True, False])
178 @pytest.mark.parametrize('merge', [True, False])
179 def test_compare_commits_of_same_repository(self, vcsbackend, merge):
179 def test_compare_commits_of_same_repository(self, vcsbackend, merge):
180 target_repo = vcsbackend.create_repo(number_of_commits=5)
180 target_repo = vcsbackend.create_repo(number_of_commits=5)
181 target_repo.compare(
181 target_repo.compare(
182 target_repo[1].raw_id, target_repo[3].raw_id, target_repo,
182 target_repo[1].raw_id, target_repo[3].raw_id, target_repo,
183 merge=merge)
183 merge=merge)
184
184
185 @pytest.mark.xfail_backends('svn')
185 @pytest.mark.xfail_backends('svn')
186 @pytest.mark.parametrize('merge', [True, False])
186 @pytest.mark.parametrize('merge', [True, False])
187 def test_compare_cloned_repositories(self, vcsbackend, merge):
187 def test_compare_cloned_repositories(self, vcsbackend, merge):
188 target_repo = vcsbackend.create_repo(number_of_commits=5)
188 target_repo = vcsbackend.create_repo(number_of_commits=5)
189 source_repo = vcsbackend.clone_repo(target_repo)
189 source_repo = vcsbackend.clone_repo(target_repo)
190 assert target_repo != source_repo
190 assert target_repo != source_repo
191
191
192 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
192 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
193 source_commit = source_repo.get_commit()
193 source_commit = source_repo.get_commit()
194
194
195 target_repo.compare(
195 target_repo.compare(
196 target_repo[1].raw_id, source_repo[3].raw_id, source_repo,
196 target_repo[1].raw_id, source_repo[3].raw_id, source_repo,
197 merge=merge)
197 merge=merge)
198
198
199 @pytest.mark.xfail_backends('svn')
199 @pytest.mark.xfail_backends('svn')
200 @pytest.mark.parametrize('merge', [True, False])
200 @pytest.mark.parametrize('merge', [True, False])
201 def test_compare_unrelated_repositories(self, vcsbackend, merge):
201 def test_compare_unrelated_repositories(self, vcsbackend, merge):
202 orig = vcsbackend.create_repo(number_of_commits=5)
202 orig = vcsbackend.create_repo(number_of_commits=5)
203 unrelated = vcsbackend.create_repo(number_of_commits=5)
203 unrelated = vcsbackend.create_repo(number_of_commits=5)
204 assert orig != unrelated
204 assert orig != unrelated
205
205
206 orig.compare(
206 orig.compare(
207 orig[1].raw_id, unrelated[3].raw_id, unrelated, merge=merge)
207 orig[1].raw_id, unrelated[3].raw_id, unrelated, merge=merge)
208
208
209
209
210 class TestRepositoryGetCommonAncestor:
210 class TestRepositoryGetCommonAncestor:
211
211
212 def test_get_common_ancestor_from_same_repo_existing(self, vcsbackend):
212 def test_get_common_ancestor_from_same_repo_existing(self, vcsbackend):
213 target_repo = vcsbackend.create_repo(number_of_commits=5)
213 target_repo = vcsbackend.create_repo(number_of_commits=5)
214
214
215 expected_ancestor = target_repo[2].raw_id
215 expected_ancestor = target_repo[2].raw_id
216
216
217 assert target_repo.get_common_ancestor(
217 assert target_repo.get_common_ancestor(
218 commit_id1=target_repo[2].raw_id,
218 commit_id1=target_repo[2].raw_id,
219 commit_id2=target_repo[4].raw_id,
219 commit_id2=target_repo[4].raw_id,
220 repo2=target_repo
220 repo2=target_repo
221 ) == expected_ancestor
221 ) == expected_ancestor
222
222
223 assert target_repo.get_common_ancestor(
223 assert target_repo.get_common_ancestor(
224 commit_id1=target_repo[4].raw_id,
224 commit_id1=target_repo[4].raw_id,
225 commit_id2=target_repo[2].raw_id,
225 commit_id2=target_repo[2].raw_id,
226 repo2=target_repo
226 repo2=target_repo
227 ) == expected_ancestor
227 ) == expected_ancestor
228
228
229 @pytest.mark.xfail_backends("svn")
229 @pytest.mark.xfail_backends("svn")
230 def test_get_common_ancestor_from_cloned_repo_existing(self, vcsbackend):
230 def test_get_common_ancestor_from_cloned_repo_existing(self, vcsbackend):
231 target_repo = vcsbackend.create_repo(number_of_commits=5)
231 target_repo = vcsbackend.create_repo(number_of_commits=5)
232 source_repo = vcsbackend.clone_repo(target_repo)
232 source_repo = vcsbackend.clone_repo(target_repo)
233 assert target_repo != source_repo
233 assert target_repo != source_repo
234
234
235 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
235 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
236 source_commit = source_repo.get_commit()
236 source_commit = source_repo.get_commit()
237
237
238 expected_ancestor = target_repo[4].raw_id
238 expected_ancestor = target_repo[4].raw_id
239
239
240 assert target_repo.get_common_ancestor(
240 assert target_repo.get_common_ancestor(
241 commit_id1=target_repo[4].raw_id,
241 commit_id1=target_repo[4].raw_id,
242 commit_id2=source_commit.raw_id,
242 commit_id2=source_commit.raw_id,
243 repo2=source_repo
243 repo2=source_repo
244 ) == expected_ancestor
244 ) == expected_ancestor
245
245
246 assert target_repo.get_common_ancestor(
246 assert target_repo.get_common_ancestor(
247 commit_id1=source_commit.raw_id,
247 commit_id1=source_commit.raw_id,
248 commit_id2=target_repo[4].raw_id,
248 commit_id2=target_repo[4].raw_id,
249 repo2=target_repo
249 repo2=target_repo
250 ) == expected_ancestor
250 ) == expected_ancestor
251
251
252 @pytest.mark.xfail_backends("svn")
252 @pytest.mark.xfail_backends("svn")
253 def test_get_common_ancestor_from_unrelated_repo_missing(self, vcsbackend):
253 def test_get_common_ancestor_from_unrelated_repo_missing(self, vcsbackend):
254 original = vcsbackend.create_repo(number_of_commits=5)
254 original = vcsbackend.create_repo(number_of_commits=5)
255 unrelated = vcsbackend.create_repo(number_of_commits=5)
255 unrelated = vcsbackend.create_repo(number_of_commits=5)
256 assert original != unrelated
256 assert original != unrelated
257
257
258 assert original.get_common_ancestor(
258 assert original.get_common_ancestor(
259 commit_id1=original[0].raw_id,
259 commit_id1=original[0].raw_id,
260 commit_id2=unrelated[0].raw_id,
260 commit_id2=unrelated[0].raw_id,
261 repo2=unrelated
261 repo2=unrelated
262 ) == None
262 ) is None
263
263
264 assert original.get_common_ancestor(
264 assert original.get_common_ancestor(
265 commit_id1=original[-1].raw_id,
265 commit_id1=original[-1].raw_id,
266 commit_id2=unrelated[-1].raw_id,
266 commit_id2=unrelated[-1].raw_id,
267 repo2=unrelated
267 repo2=unrelated
268 ) == None
268 ) is None
269
269
270
270
271 @pytest.mark.backends("git", "hg")
271 @pytest.mark.backends("git", "hg")
272 class TestRepositoryMerge(object):
272 class TestRepositoryMerge(object):
273 def prepare_for_success(self, vcsbackend):
273 def prepare_for_success(self, vcsbackend):
274 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
274 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
275 self.source_repo = vcsbackend.clone_repo(self.target_repo)
275 self.source_repo = vcsbackend.clone_repo(self.target_repo)
276 vcsbackend.add_file(self.target_repo, 'README_MERGE1', 'Version 1')
276 vcsbackend.add_file(self.target_repo, 'README_MERGE1', 'Version 1')
277 vcsbackend.add_file(self.source_repo, 'README_MERGE2', 'Version 2')
277 vcsbackend.add_file(self.source_repo, 'README_MERGE2', 'Version 2')
278 imc = self.source_repo.in_memory_commit
278 imc = self.source_repo.in_memory_commit
279 imc.add(FileNode('file_x', content=self.source_repo.name))
279 imc.add(FileNode('file_x', content=self.source_repo.name))
280 imc.commit(
280 imc.commit(
281 message=u'Automatic commit from repo merge test',
281 message=u'Automatic commit from repo merge test',
282 author=u'Automatic')
282 author=u'Automatic')
283 self.target_commit = self.target_repo.get_commit()
283 self.target_commit = self.target_repo.get_commit()
284 self.source_commit = self.source_repo.get_commit()
284 self.source_commit = self.source_repo.get_commit()
285 # This only works for Git and Mercurial
285 # This only works for Git and Mercurial
286 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
286 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
287 self.target_ref = Reference(
287 self.target_ref = Reference(
288 'branch', default_branch, self.target_commit.raw_id)
288 'branch', default_branch, self.target_commit.raw_id)
289 self.source_ref = Reference(
289 self.source_ref = Reference(
290 'branch', default_branch, self.source_commit.raw_id)
290 'branch', default_branch, self.source_commit.raw_id)
291 self.workspace_id = 'test-merge'
291 self.workspace_id = 'test-merge'
292 self.repo_id = repo_id_generator(self.target_repo.path)
292 self.repo_id = repo_id_generator(self.target_repo.path)
293
293
294 def prepare_for_conflict(self, vcsbackend):
294 def prepare_for_conflict(self, vcsbackend):
295 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
295 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
296 self.source_repo = vcsbackend.clone_repo(self.target_repo)
296 self.source_repo = vcsbackend.clone_repo(self.target_repo)
297 vcsbackend.add_file(self.target_repo, 'README_MERGE', 'Version 1')
297 vcsbackend.add_file(self.target_repo, 'README_MERGE', 'Version 1')
298 vcsbackend.add_file(self.source_repo, 'README_MERGE', 'Version 2')
298 vcsbackend.add_file(self.source_repo, 'README_MERGE', 'Version 2')
299 self.target_commit = self.target_repo.get_commit()
299 self.target_commit = self.target_repo.get_commit()
300 self.source_commit = self.source_repo.get_commit()
300 self.source_commit = self.source_repo.get_commit()
301 # This only works for Git and Mercurial
301 # This only works for Git and Mercurial
302 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
302 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
303 self.target_ref = Reference(
303 self.target_ref = Reference(
304 'branch', default_branch, self.target_commit.raw_id)
304 'branch', default_branch, self.target_commit.raw_id)
305 self.source_ref = Reference(
305 self.source_ref = Reference(
306 'branch', default_branch, self.source_commit.raw_id)
306 'branch', default_branch, self.source_commit.raw_id)
307 self.workspace_id = 'test-merge'
307 self.workspace_id = 'test-merge'
308 self.repo_id = repo_id_generator(self.target_repo.path)
308 self.repo_id = repo_id_generator(self.target_repo.path)
309
309
310 def test_merge_success(self, vcsbackend):
310 def test_merge_success(self, vcsbackend):
311 self.prepare_for_success(vcsbackend)
311 self.prepare_for_success(vcsbackend)
312
312
313 merge_response = self.target_repo.merge(
313 merge_response = self.target_repo.merge(
314 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
314 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
315 self.source_ref,
315 self.source_ref,
316 'test user', 'test@rhodecode.com', 'merge message 1',
316 'test user', 'test@rhodecode.com', 'merge message 1',
317 dry_run=False)
317 dry_run=False)
318 expected_merge_response = MergeResponse(
318 expected_merge_response = MergeResponse(
319 True, True, merge_response.merge_ref,
319 True, True, merge_response.merge_ref,
320 MergeFailureReason.NONE)
320 MergeFailureReason.NONE)
321 assert merge_response == expected_merge_response
321 assert merge_response == expected_merge_response
322
322
323 target_repo = backends.get_backend(vcsbackend.alias)(
323 target_repo = backends.get_backend(vcsbackend.alias)(
324 self.target_repo.path)
324 self.target_repo.path)
325 target_commits = list(target_repo.get_commits())
325 target_commits = list(target_repo.get_commits())
326 commit_ids = [c.raw_id for c in target_commits[:-1]]
326 commit_ids = [c.raw_id for c in target_commits[:-1]]
327 assert self.source_ref.commit_id in commit_ids
327 assert self.source_ref.commit_id in commit_ids
328 assert self.target_ref.commit_id in commit_ids
328 assert self.target_ref.commit_id in commit_ids
329
329
330 merge_commit = target_commits[-1]
330 merge_commit = target_commits[-1]
331 assert merge_commit.raw_id == merge_response.merge_ref.commit_id
331 assert merge_commit.raw_id == merge_response.merge_ref.commit_id
332 assert merge_commit.message.strip() == 'merge message 1'
332 assert merge_commit.message.strip() == 'merge message 1'
333 assert merge_commit.author == 'test user <test@rhodecode.com>'
333 assert merge_commit.author == 'test user <test@rhodecode.com>'
334
334
335 # We call it twice so to make sure we can handle updates
335 # We call it twice so to make sure we can handle updates
336 target_ref = Reference(
336 target_ref = Reference(
337 self.target_ref.type, self.target_ref.name,
337 self.target_ref.type, self.target_ref.name,
338 merge_response.merge_ref.commit_id)
338 merge_response.merge_ref.commit_id)
339
339
340 merge_response = target_repo.merge(
340 merge_response = target_repo.merge(
341 self.repo_id, self.workspace_id, target_ref, self.source_repo, self.source_ref,
341 self.repo_id, self.workspace_id, target_ref, self.source_repo, self.source_ref,
342 'test user', 'test@rhodecode.com', 'merge message 2',
342 'test user', 'test@rhodecode.com', 'merge message 2',
343 dry_run=False)
343 dry_run=False)
344 expected_merge_response = MergeResponse(
344 expected_merge_response = MergeResponse(
345 True, True, merge_response.merge_ref,
345 True, True, merge_response.merge_ref,
346 MergeFailureReason.NONE)
346 MergeFailureReason.NONE)
347 assert merge_response == expected_merge_response
347 assert merge_response == expected_merge_response
348
348
349 target_repo = backends.get_backend(
349 target_repo = backends.get_backend(
350 vcsbackend.alias)(self.target_repo.path)
350 vcsbackend.alias)(self.target_repo.path)
351 merge_commit = target_repo.get_commit(
351 merge_commit = target_repo.get_commit(
352 merge_response.merge_ref.commit_id)
352 merge_response.merge_ref.commit_id)
353 assert merge_commit.message.strip() == 'merge message 1'
353 assert merge_commit.message.strip() == 'merge message 1'
354 assert merge_commit.author == 'test user <test@rhodecode.com>'
354 assert merge_commit.author == 'test user <test@rhodecode.com>'
355
355
356 def test_merge_success_dry_run(self, vcsbackend):
356 def test_merge_success_dry_run(self, vcsbackend):
357 self.prepare_for_success(vcsbackend)
357 self.prepare_for_success(vcsbackend)
358
358
359 merge_response = self.target_repo.merge(
359 merge_response = self.target_repo.merge(
360 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
360 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
361 self.source_ref, dry_run=True)
361 self.source_ref, dry_run=True)
362
362
363 # We call it twice so to make sure we can handle updates
363 # We call it twice so to make sure we can handle updates
364 merge_response_update = self.target_repo.merge(
364 merge_response_update = self.target_repo.merge(
365 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
365 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
366 self.source_ref, dry_run=True)
366 self.source_ref, dry_run=True)
367
367
368 # Multiple merges may differ in their commit id. Therefore we set the
368 # Multiple merges may differ in their commit id. Therefore we set the
369 # commit id to `None` before comparing the merge responses.
369 # commit id to `None` before comparing the merge responses.
370 merge_response = merge_response._replace(
370 merge_response = merge_response._replace(
371 merge_ref=merge_response.merge_ref._replace(commit_id=None))
371 merge_ref=merge_response.merge_ref._replace(commit_id=None))
372 merge_response_update = merge_response_update._replace(
372 merge_response_update = merge_response_update._replace(
373 merge_ref=merge_response_update.merge_ref._replace(commit_id=None))
373 merge_ref=merge_response_update.merge_ref._replace(commit_id=None))
374
374
375 assert merge_response == merge_response_update
375 assert merge_response == merge_response_update
376 assert merge_response.possible is True
376 assert merge_response.possible is True
377 assert merge_response.executed is False
377 assert merge_response.executed is False
378 assert merge_response.merge_ref
378 assert merge_response.merge_ref
379 assert merge_response.failure_reason is MergeFailureReason.NONE
379 assert merge_response.failure_reason is MergeFailureReason.NONE
380
380
381 @pytest.mark.parametrize('dry_run', [True, False])
381 @pytest.mark.parametrize('dry_run', [True, False])
382 def test_merge_conflict(self, vcsbackend, dry_run):
382 def test_merge_conflict(self, vcsbackend, dry_run):
383 self.prepare_for_conflict(vcsbackend)
383 self.prepare_for_conflict(vcsbackend)
384 expected_merge_response = MergeResponse(
384 expected_merge_response = MergeResponse(
385 False, False, None, MergeFailureReason.MERGE_FAILED)
385 False, False, None, MergeFailureReason.MERGE_FAILED)
386
386
387 merge_response = self.target_repo.merge(
387 merge_response = self.target_repo.merge(
388 self.repo_id, self.workspace_id, self.target_ref,
388 self.repo_id, self.workspace_id, self.target_ref,
389 self.source_repo, self.source_ref,
389 self.source_repo, self.source_ref,
390 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
390 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
391 assert merge_response == expected_merge_response
391 assert merge_response == expected_merge_response
392
392
393 # We call it twice so to make sure we can handle updates
393 # We call it twice so to make sure we can handle updates
394 merge_response = self.target_repo.merge(
394 merge_response = self.target_repo.merge(
395 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
395 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
396 self.source_ref,
396 self.source_ref,
397 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
397 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
398 assert merge_response == expected_merge_response
398 assert merge_response == expected_merge_response
399
399
400 def test_merge_target_is_not_head(self, vcsbackend):
400 def test_merge_target_is_not_head(self, vcsbackend):
401 self.prepare_for_success(vcsbackend)
401 self.prepare_for_success(vcsbackend)
402 expected_merge_response = MergeResponse(
402 expected_merge_response = MergeResponse(
403 False, False, None, MergeFailureReason.TARGET_IS_NOT_HEAD)
403 False, False, None, MergeFailureReason.TARGET_IS_NOT_HEAD)
404
404
405 target_ref = Reference(
405 target_ref = Reference(
406 self.target_ref.type, self.target_ref.name, '0' * 40)
406 self.target_ref.type, self.target_ref.name, '0' * 40)
407
407
408 merge_response = self.target_repo.merge(
408 merge_response = self.target_repo.merge(
409 self.repo_id, self.workspace_id, target_ref, self.source_repo,
409 self.repo_id, self.workspace_id, target_ref, self.source_repo,
410 self.source_ref, dry_run=True)
410 self.source_ref, dry_run=True)
411
411
412 assert merge_response == expected_merge_response
412 assert merge_response == expected_merge_response
413
413
414 def test_merge_missing_source_reference(self, vcsbackend):
414 def test_merge_missing_source_reference(self, vcsbackend):
415 self.prepare_for_success(vcsbackend)
415 self.prepare_for_success(vcsbackend)
416 expected_merge_response = MergeResponse(
416 expected_merge_response = MergeResponse(
417 False, False, None, MergeFailureReason.MISSING_SOURCE_REF)
417 False, False, None, MergeFailureReason.MISSING_SOURCE_REF)
418
418
419 source_ref = Reference(
419 source_ref = Reference(
420 self.source_ref.type, 'not_existing', self.source_ref.commit_id)
420 self.source_ref.type, 'not_existing', self.source_ref.commit_id)
421
421
422 merge_response = self.target_repo.merge(
422 merge_response = self.target_repo.merge(
423 self.repo_id, self.workspace_id, self.target_ref,
423 self.repo_id, self.workspace_id, self.target_ref,
424 self.source_repo, source_ref,
424 self.source_repo, source_ref,
425 dry_run=True)
425 dry_run=True)
426
426
427 assert merge_response == expected_merge_response
427 assert merge_response == expected_merge_response
428
428
429 def test_merge_raises_exception(self, vcsbackend):
429 def test_merge_raises_exception(self, vcsbackend):
430 self.prepare_for_success(vcsbackend)
430 self.prepare_for_success(vcsbackend)
431 expected_merge_response = MergeResponse(
431 expected_merge_response = MergeResponse(
432 False, False, None, MergeFailureReason.UNKNOWN)
432 False, False, None, MergeFailureReason.UNKNOWN)
433
433
434 with mock.patch.object(self.target_repo, '_merge_repo',
434 with mock.patch.object(self.target_repo, '_merge_repo',
435 side_effect=RepositoryError()):
435 side_effect=RepositoryError()):
436 merge_response = self.target_repo.merge(
436 merge_response = self.target_repo.merge(
437 self.repo_id, self.workspace_id, self.target_ref,
437 self.repo_id, self.workspace_id, self.target_ref,
438 self.source_repo, self.source_ref,
438 self.source_repo, self.source_ref,
439 dry_run=True)
439 dry_run=True)
440
440
441 assert merge_response == expected_merge_response
441 assert merge_response == expected_merge_response
442
442
443 def test_merge_invalid_user_name(self, vcsbackend):
443 def test_merge_invalid_user_name(self, vcsbackend):
444 repo = vcsbackend.create_repo(number_of_commits=1)
444 repo = vcsbackend.create_repo(number_of_commits=1)
445 ref = Reference('branch', 'master', 'not_used')
445 ref = Reference('branch', 'master', 'not_used')
446 workspace_id = 'test-errors-in-merge'
446 workspace_id = 'test-errors-in-merge'
447 repo_id = repo_id_generator(workspace_id)
447 repo_id = repo_id_generator(workspace_id)
448 with pytest.raises(ValueError):
448 with pytest.raises(ValueError):
449 repo.merge(repo_id, workspace_id, ref, self, ref)
449 repo.merge(repo_id, workspace_id, ref, self, ref)
450
450
451 def test_merge_invalid_user_email(self, vcsbackend):
451 def test_merge_invalid_user_email(self, vcsbackend):
452 repo = vcsbackend.create_repo(number_of_commits=1)
452 repo = vcsbackend.create_repo(number_of_commits=1)
453 ref = Reference('branch', 'master', 'not_used')
453 ref = Reference('branch', 'master', 'not_used')
454 workspace_id = 'test-errors-in-merge'
454 workspace_id = 'test-errors-in-merge'
455 repo_id = repo_id_generator(workspace_id)
455 repo_id = repo_id_generator(workspace_id)
456 with pytest.raises(ValueError):
456 with pytest.raises(ValueError):
457 repo.merge(
457 repo.merge(
458 repo_id, workspace_id, ref, self, ref, 'user name')
458 repo_id, workspace_id, ref, self, ref, 'user name')
459
459
460 def test_merge_invalid_message(self, vcsbackend):
460 def test_merge_invalid_message(self, vcsbackend):
461 repo = vcsbackend.create_repo(number_of_commits=1)
461 repo = vcsbackend.create_repo(number_of_commits=1)
462 ref = Reference('branch', 'master', 'not_used')
462 ref = Reference('branch', 'master', 'not_used')
463 workspace_id = 'test-errors-in-merge'
463 workspace_id = 'test-errors-in-merge'
464 repo_id = repo_id_generator(workspace_id)
464 repo_id = repo_id_generator(workspace_id)
465 with pytest.raises(ValueError):
465 with pytest.raises(ValueError):
466 repo.merge(
466 repo.merge(
467 repo_id, workspace_id, ref, self, ref,
467 repo_id, workspace_id, ref, self, ref,
468 'user name', 'user@email.com')
468 'user name', 'user@email.com')
469
469
470
470
471 @pytest.mark.usefixtures("vcs_repository_support")
471 @pytest.mark.usefixtures("vcs_repository_support")
472 class TestRepositoryStrip(BackendTestMixin):
472 class TestRepositoryStrip(BackendTestMixin):
473 recreate_repo_per_test = True
473 recreate_repo_per_test = True
474
474
475 @classmethod
475 @classmethod
476 def _get_commits(cls):
476 def _get_commits(cls):
477 commits = [
477 commits = [
478 {
478 {
479 'message': 'Initial commit',
479 'message': 'Initial commit',
480 'author': 'Joe Doe <joe.doe@example.com>',
480 'author': 'Joe Doe <joe.doe@example.com>',
481 'date': datetime.datetime(2010, 1, 1, 20),
481 'date': datetime.datetime(2010, 1, 1, 20),
482 'branch': 'master',
482 'branch': 'master',
483 'added': [
483 'added': [
484 FileNode('foobar', content='foobar'),
484 FileNode('foobar', content='foobar'),
485 FileNode('foobar2', content='foobar2'),
485 FileNode('foobar2', content='foobar2'),
486 ],
486 ],
487 },
487 },
488 ]
488 ]
489 for x in xrange(10):
489 for x in xrange(10):
490 commit_data = {
490 commit_data = {
491 'message': 'Changed foobar - commit%s' % x,
491 'message': 'Changed foobar - commit%s' % x,
492 'author': 'Jane Doe <jane.doe@example.com>',
492 'author': 'Jane Doe <jane.doe@example.com>',
493 'date': datetime.datetime(2010, 1, 1, 21, x),
493 'date': datetime.datetime(2010, 1, 1, 21, x),
494 'branch': 'master',
494 'branch': 'master',
495 'changed': [
495 'changed': [
496 FileNode('foobar', 'FOOBAR - %s' % x),
496 FileNode('foobar', 'FOOBAR - %s' % x),
497 ],
497 ],
498 }
498 }
499 commits.append(commit_data)
499 commits.append(commit_data)
500 return commits
500 return commits
501
501
502 @pytest.mark.backends("git", "hg")
502 @pytest.mark.backends("git", "hg")
503 def test_strip_commit(self):
503 def test_strip_commit(self):
504 tip = self.repo.get_commit()
504 tip = self.repo.get_commit()
505 assert tip.idx == 10
505 assert tip.idx == 10
506 self.repo.strip(tip.raw_id, self.repo.DEFAULT_BRANCH_NAME)
506 self.repo.strip(tip.raw_id, self.repo.DEFAULT_BRANCH_NAME)
507
507
508 tip = self.repo.get_commit()
508 tip = self.repo.get_commit()
509 assert tip.idx == 9
509 assert tip.idx == 9
510
510
511 @pytest.mark.backends("git", "hg")
511 @pytest.mark.backends("git", "hg")
512 def test_strip_multiple_commits(self):
512 def test_strip_multiple_commits(self):
513 tip = self.repo.get_commit()
513 tip = self.repo.get_commit()
514 assert tip.idx == 10
514 assert tip.idx == 10
515
515
516 old = self.repo.get_commit(commit_idx=5)
516 old = self.repo.get_commit(commit_idx=5)
517 self.repo.strip(old.raw_id, self.repo.DEFAULT_BRANCH_NAME)
517 self.repo.strip(old.raw_id, self.repo.DEFAULT_BRANCH_NAME)
518
518
519 tip = self.repo.get_commit()
519 tip = self.repo.get_commit()
520 assert tip.idx == 4
520 assert tip.idx == 4
521
521
522
522
523 @pytest.mark.backends('hg', 'git')
523 @pytest.mark.backends('hg', 'git')
524 class TestRepositoryPull(object):
524 class TestRepositoryPull(object):
525
525
526 def test_pull(self, vcsbackend):
526 def test_pull(self, vcsbackend):
527 source_repo = vcsbackend.repo
527 source_repo = vcsbackend.repo
528 target_repo = vcsbackend.create_repo()
528 target_repo = vcsbackend.create_repo()
529 assert len(source_repo.commit_ids) > len(target_repo.commit_ids)
529 assert len(source_repo.commit_ids) > len(target_repo.commit_ids)
530
530
531 target_repo.pull(source_repo.path)
531 target_repo.pull(source_repo.path)
532 # Note: Get a fresh instance, avoids caching trouble
532 # Note: Get a fresh instance, avoids caching trouble
533 target_repo = vcsbackend.backend(target_repo.path)
533 target_repo = vcsbackend.backend(target_repo.path)
534 assert len(source_repo.commit_ids) == len(target_repo.commit_ids)
534 assert len(source_repo.commit_ids) == len(target_repo.commit_ids)
535
535
536 def test_pull_wrong_path(self, vcsbackend):
536 def test_pull_wrong_path(self, vcsbackend):
537 target_repo = vcsbackend.create_repo()
537 target_repo = vcsbackend.create_repo()
538 with pytest.raises(RepositoryError):
538 with pytest.raises(RepositoryError):
539 target_repo.pull(target_repo.path + "wrong")
539 target_repo.pull(target_repo.path + "wrong")
540
540
541 def test_pull_specific_commits(self, vcsbackend):
541 def test_pull_specific_commits(self, vcsbackend):
542 source_repo = vcsbackend.repo
542 source_repo = vcsbackend.repo
543 target_repo = vcsbackend.create_repo()
543 target_repo = vcsbackend.create_repo()
544
544
545 second_commit = source_repo[1].raw_id
545 second_commit = source_repo[1].raw_id
546 if vcsbackend.alias == 'git':
546 if vcsbackend.alias == 'git':
547 second_commit_ref = 'refs/test-refs/a'
547 second_commit_ref = 'refs/test-refs/a'
548 source_repo.set_refs(second_commit_ref, second_commit)
548 source_repo.set_refs(second_commit_ref, second_commit)
549
549
550 target_repo.pull(source_repo.path, commit_ids=[second_commit])
550 target_repo.pull(source_repo.path, commit_ids=[second_commit])
551 target_repo = vcsbackend.backend(target_repo.path)
551 target_repo = vcsbackend.backend(target_repo.path)
552 assert 2 == len(target_repo.commit_ids)
552 assert 2 == len(target_repo.commit_ids)
553 assert second_commit == target_repo.get_commit().raw_id
553 assert second_commit == target_repo.get_commit().raw_id
General Comments 0
You need to be logged in to leave comments. Login now