##// END OF EJS Templates
Added recursive scanning for repositories in directory
marcink -
r877:bc9a73ad beta
parent child Browse files
Show More
@@ -1,614 +1,621
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.lib.utils
4 4 ~~~~~~~~~~~~~~~~~~~
5 5
6 6 Utilities library for RhodeCode
7 7
8 8 :created_on: Apr 18, 2010
9 9 :author: marcink
10 10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software; you can redistribute it and/or
14 14 # modify it under the terms of the GNU General Public License
15 15 # as published by the Free Software Foundation; version 2
16 16 # of the License or (at your option) any later version of the license.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program; if not, write to the Free Software
25 25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 26 # MA 02110-1301, USA.
27 27
28 28 import os
29 29 import logging
30 30 import datetime
31 31 import traceback
32 32
33 33 from UserDict import DictMixin
34 34
35 35 from mercurial import ui, config, hg
36 36 from mercurial.error import RepoError
37 37
38 38 import paste
39 39 import beaker
40 40 from paste.script.command import Command, BadCommand
41 41
42 42 from vcs.backends.base import BaseChangeset
43 43 from vcs.utils.lazy import LazyProperty
44 44
45 45 from rhodecode.model import meta
46 46 from rhodecode.model.caching_query import FromCache
47 47 from rhodecode.model.db import Repository, User, RhodeCodeUi, UserLog
48 48 from rhodecode.model.repo import RepoModel
49 49 from rhodecode.model.user import UserModel
50 50
51 51 log = logging.getLogger(__name__)
52 52
53 53
54 54 def get_repo_slug(request):
55 55 return request.environ['pylons.routes_dict'].get('repo_name')
56 56
57 57 def action_logger(user, action, repo, ipaddr='', sa=None):
58 58 """
59 59 Action logger for various actions made by users
60 60
61 61 :param user: user that made this action, can be a unique username string or
62 62 object containing user_id attribute
63 63 :param action: action to log, should be on of predefined unique actions for
64 64 easy translations
65 65 :param repo: string name of repository or object containing repo_id,
66 66 that action was made on
67 67 :param ipaddr: optional ip address from what the action was made
68 68 :param sa: optional sqlalchemy session
69 69
70 70 """
71 71
72 72 if not sa:
73 73 sa = meta.Session()
74 74
75 75 try:
76 76 um = UserModel()
77 77 if hasattr(user, 'user_id'):
78 78 user_obj = user
79 79 elif isinstance(user, basestring):
80 80 user_obj = um.get_by_username(user, cache=False)
81 81 else:
82 82 raise Exception('You have to provide user object or username')
83 83
84 84
85 85 rm = RepoModel()
86 86 if hasattr(repo, 'repo_id'):
87 87 repo_obj = rm.get(repo.repo_id, cache=False)
88 88 repo_name = repo_obj.repo_name
89 89 elif isinstance(repo, basestring):
90 90 repo_name = repo.lstrip('/')
91 91 repo_obj = rm.get_by_repo_name(repo_name, cache=False)
92 92 else:
93 93 raise Exception('You have to provide repository to action logger')
94 94
95 95
96 96 user_log = UserLog()
97 97 user_log.user_id = user_obj.user_id
98 98 user_log.action = action
99 99
100 100 user_log.repository_id = repo_obj.repo_id
101 101 user_log.repository_name = repo_name
102 102
103 103 user_log.action_date = datetime.datetime.now()
104 104 user_log.user_ip = ipaddr
105 105 sa.add(user_log)
106 106 sa.commit()
107 107
108 108 log.info('Adding user %s, action %s on %s', user_obj, action, repo)
109 109 except:
110 110 log.error(traceback.format_exc())
111 111 sa.rollback()
112 112
113 def get_repos(path, recursive=False, initial=False):
113 def get_repos(path, recursive=False):
114 114 """
115 115 Scans given path for repos and return (name,(type,path)) tuple
116 116
117 :param prefix:
118 :param path:
119 :param recursive:
120 :param initial:
117 :param path: path to scan for repositories
118 :param recursive: recursive search and return names with subdirs in front
121 119 """
122 120 from vcs.utils.helpers import get_scm
123 121 from vcs.exceptions import VCSError
124 122
123 if path.endswith('/'):
124 #strip trailing slash for better results
125 path = path[:-1]
126
127 def _get_repos(p):
128 for dirpath in os.listdir(p):
129 if os.path.isfile(os.path.join(p, dirpath)):
130 continue
131 cur_path = os.path.join(p, dirpath)
125 132 try:
126 scm = get_scm(path)
127 except:
128 pass
129 else:
130 raise Exception('The given path %s should not be a repository got %s',
131 path, scm)
133 scm_info = get_scm(cur_path)
134 yield scm_info[1].split(path)[-1].lstrip('/'), scm_info
135 except VCSError:
136 if not recursive:
137 continue
138 #check if this dir contains other repos for recursive scan
139 rec_path = os.path.join(p, dirpath)
140 if os.path.isdir(rec_path):
141 for inner_scm in _get_repos(rec_path):
142 yield inner_scm
132 143
133 for dirpath in os.listdir(path):
134 try:
135 yield dirpath, get_scm(os.path.join(path, dirpath))
136 except VCSError:
137 pass
144 return _get_repos(path)
138 145
139 146 def check_repo_fast(repo_name, base_path):
140 147 """
141 148 Check given path for existence of directory
142 149 :param repo_name:
143 150 :param base_path:
144 151
145 152 :return False: if this directory is present
146 153 """
147 154 if os.path.isdir(os.path.join(base_path, repo_name)):return False
148 155 return True
149 156
150 157 def check_repo(repo_name, base_path, verify=True):
151 158
152 159 repo_path = os.path.join(base_path, repo_name)
153 160
154 161 try:
155 162 if not check_repo_fast(repo_name, base_path):
156 163 return False
157 164 r = hg.repository(ui.ui(), repo_path)
158 165 if verify:
159 166 hg.verify(r)
160 167 #here we know that the repo exists and was verified
161 168 log.info('%s repo is already created', repo_name)
162 169 return False
163 170 except RepoError:
164 171 #it means that there is no valid repo there...
165 172 log.info('%s repo is free for creation', repo_name)
166 173 return True
167 174
168 175 def ask_ok(prompt, retries=4, complaint='Yes or no, please!'):
169 176 while True:
170 177 ok = raw_input(prompt)
171 178 if ok in ('y', 'ye', 'yes'): return True
172 179 if ok in ('n', 'no', 'nop', 'nope'): return False
173 180 retries = retries - 1
174 181 if retries < 0: raise IOError
175 182 print complaint
176 183
177 184 #propagated from mercurial documentation
178 185 ui_sections = ['alias', 'auth',
179 186 'decode/encode', 'defaults',
180 187 'diff', 'email',
181 188 'extensions', 'format',
182 189 'merge-patterns', 'merge-tools',
183 190 'hooks', 'http_proxy',
184 191 'smtp', 'patch',
185 192 'paths', 'profiling',
186 193 'server', 'trusted',
187 194 'ui', 'web', ]
188 195
189 196 def make_ui(read_from='file', path=None, checkpaths=True):
190 197 """
191 198 A function that will read python rc files or database
192 199 and make an mercurial ui object from read options
193 200
194 201 :param path: path to mercurial config file
195 202 :param checkpaths: check the path
196 203 :param read_from: read from 'file' or 'db'
197 204 """
198 205
199 206 baseui = ui.ui()
200 207
201 208 #clean the baseui object
202 209 baseui._ocfg = config.config()
203 210 baseui._ucfg = config.config()
204 211 baseui._tcfg = config.config()
205 212
206 213 if read_from == 'file':
207 214 if not os.path.isfile(path):
208 215 log.warning('Unable to read config file %s' % path)
209 216 return False
210 217 log.debug('reading hgrc from %s', path)
211 218 cfg = config.config()
212 219 cfg.read(path)
213 220 for section in ui_sections:
214 221 for k, v in cfg.items(section):
215 222 log.debug('settings ui from file[%s]%s:%s', section, k, v)
216 223 baseui.setconfig(section, k, v)
217 224
218 225
219 226 elif read_from == 'db':
220 227 sa = meta.Session()
221 228 ret = sa.query(RhodeCodeUi)\
222 229 .options(FromCache("sql_cache_short",
223 230 "get_hg_ui_settings")).all()
224 231
225 232 hg_ui = ret
226 233 for ui_ in hg_ui:
227 234 if ui_.ui_active:
228 235 log.debug('settings ui from db[%s]%s:%s', ui_.ui_section,
229 236 ui_.ui_key, ui_.ui_value)
230 237 baseui.setconfig(ui_.ui_section, ui_.ui_key, ui_.ui_value)
231 238
232 239 meta.Session.remove()
233 240 return baseui
234 241
235 242
236 243 def set_rhodecode_config(config):
237 244 """Updates pylons config with new settings from database
238 245
239 246 :param config:
240 247 """
241 248 from rhodecode.model.settings import SettingsModel
242 249 hgsettings = SettingsModel().get_app_settings()
243 250
244 251 for k, v in hgsettings.items():
245 252 config[k] = v
246 253
247 254 def invalidate_cache(cache_key, *args):
248 255 """Puts cache invalidation task into db for
249 256 further global cache invalidation
250 257 """
251 258
252 259 from rhodecode.model.scm import ScmModel
253 260
254 261 if cache_key.startswith('get_repo_cached_'):
255 262 name = cache_key.split('get_repo_cached_')[-1]
256 263 ScmModel().mark_for_invalidation(name)
257 264
258 265 class EmptyChangeset(BaseChangeset):
259 266 """
260 267 A dummy empty changeset. It's possible to pass a hash when creating
261 268 an EmptyChangeset
262 269 """
263 270
264 271 def __init__(self, cs='0' * 40):
265 272 self._empty_cs = cs
266 273 self.revision = -1
267 274 self.message = ''
268 275 self.author = ''
269 276 self.date = ''
270 277
271 278 @LazyProperty
272 279 def raw_id(self):
273 280 """Returns raw string identifying this changeset, useful for web
274 281 representation.
275 282 """
276 283
277 284 return self._empty_cs
278 285
279 286 @LazyProperty
280 287 def short_id(self):
281 288 return self.raw_id[:12]
282 289
283 290 def get_file_changeset(self, path):
284 291 return self
285 292
286 293 def get_file_content(self, path):
287 294 return u''
288 295
289 296 def get_file_size(self, path):
290 297 return 0
291 298
292 299 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
293 300 """maps all found repositories into db
294 301 """
295 302
296 303 sa = meta.Session()
297 304 rm = RepoModel()
298 305 user = sa.query(User).filter(User.admin == True).first()
299 306
300 307 for name, repo in initial_repo_list.items():
301 308 if not rm.get_by_repo_name(name, cache=False):
302 309 log.info('repository %s not found creating default', name)
303 310
304 311 form_data = {
305 312 'repo_name':name,
306 313 'repo_type':repo.alias,
307 314 'description':repo.description \
308 315 if repo.description != 'unknown' else \
309 316 '%s repository' % name,
310 317 'private':False
311 318 }
312 319 rm.create(form_data, user, just_db=True)
313 320
314 321 if remove_obsolete:
315 322 #remove from database those repositories that are not in the filesystem
316 323 for repo in sa.query(Repository).all():
317 324 if repo.repo_name not in initial_repo_list.keys():
318 325 sa.delete(repo)
319 326 sa.commit()
320 327
321 328 class OrderedDict(dict, DictMixin):
322 329
323 330 def __init__(self, *args, **kwds):
324 331 if len(args) > 1:
325 332 raise TypeError('expected at most 1 arguments, got %d' % len(args))
326 333 try:
327 334 self.__end
328 335 except AttributeError:
329 336 self.clear()
330 337 self.update(*args, **kwds)
331 338
332 339 def clear(self):
333 340 self.__end = end = []
334 341 end += [None, end, end] # sentinel node for doubly linked list
335 342 self.__map = {} # key --> [key, prev, next]
336 343 dict.clear(self)
337 344
338 345 def __setitem__(self, key, value):
339 346 if key not in self:
340 347 end = self.__end
341 348 curr = end[1]
342 349 curr[2] = end[1] = self.__map[key] = [key, curr, end]
343 350 dict.__setitem__(self, key, value)
344 351
345 352 def __delitem__(self, key):
346 353 dict.__delitem__(self, key)
347 354 key, prev, next = self.__map.pop(key)
348 355 prev[2] = next
349 356 next[1] = prev
350 357
351 358 def __iter__(self):
352 359 end = self.__end
353 360 curr = end[2]
354 361 while curr is not end:
355 362 yield curr[0]
356 363 curr = curr[2]
357 364
358 365 def __reversed__(self):
359 366 end = self.__end
360 367 curr = end[1]
361 368 while curr is not end:
362 369 yield curr[0]
363 370 curr = curr[1]
364 371
365 372 def popitem(self, last=True):
366 373 if not self:
367 374 raise KeyError('dictionary is empty')
368 375 if last:
369 376 key = reversed(self).next()
370 377 else:
371 378 key = iter(self).next()
372 379 value = self.pop(key)
373 380 return key, value
374 381
375 382 def __reduce__(self):
376 383 items = [[k, self[k]] for k in self]
377 384 tmp = self.__map, self.__end
378 385 del self.__map, self.__end
379 386 inst_dict = vars(self).copy()
380 387 self.__map, self.__end = tmp
381 388 if inst_dict:
382 389 return (self.__class__, (items,), inst_dict)
383 390 return self.__class__, (items,)
384 391
385 392 def keys(self):
386 393 return list(self)
387 394
388 395 setdefault = DictMixin.setdefault
389 396 update = DictMixin.update
390 397 pop = DictMixin.pop
391 398 values = DictMixin.values
392 399 items = DictMixin.items
393 400 iterkeys = DictMixin.iterkeys
394 401 itervalues = DictMixin.itervalues
395 402 iteritems = DictMixin.iteritems
396 403
397 404 def __repr__(self):
398 405 if not self:
399 406 return '%s()' % (self.__class__.__name__,)
400 407 return '%s(%r)' % (self.__class__.__name__, self.items())
401 408
402 409 def copy(self):
403 410 return self.__class__(self)
404 411
405 412 @classmethod
406 413 def fromkeys(cls, iterable, value=None):
407 414 d = cls()
408 415 for key in iterable:
409 416 d[key] = value
410 417 return d
411 418
412 419 def __eq__(self, other):
413 420 if isinstance(other, OrderedDict):
414 421 return len(self) == len(other) and self.items() == other.items()
415 422 return dict.__eq__(self, other)
416 423
417 424 def __ne__(self, other):
418 425 return not self == other
419 426
420 427
421 428 #set cache regions for beaker so celery can utilise it
422 429 def add_cache(settings):
423 430 cache_settings = {'regions':None}
424 431 for key in settings.keys():
425 432 for prefix in ['beaker.cache.', 'cache.']:
426 433 if key.startswith(prefix):
427 434 name = key.split(prefix)[1].strip()
428 435 cache_settings[name] = settings[key].strip()
429 436 if cache_settings['regions']:
430 437 for region in cache_settings['regions'].split(','):
431 438 region = region.strip()
432 439 region_settings = {}
433 440 for key, value in cache_settings.items():
434 441 if key.startswith(region):
435 442 region_settings[key.split('.')[1]] = value
436 443 region_settings['expire'] = int(region_settings.get('expire',
437 444 60))
438 445 region_settings.setdefault('lock_dir',
439 446 cache_settings.get('lock_dir'))
440 447 if 'type' not in region_settings:
441 448 region_settings['type'] = cache_settings.get('type',
442 449 'memory')
443 450 beaker.cache.cache_regions[region] = region_settings
444 451
445 452 def get_current_revision():
446 453 """Returns tuple of (number, id) from repository containing this package
447 454 or None if repository could not be found.
448 455 """
449 456
450 457 try:
451 458 from vcs import get_repo
452 459 from vcs.utils.helpers import get_scm
453 460 from vcs.exceptions import RepositoryError, VCSError
454 461 repopath = os.path.join(os.path.dirname(__file__), '..', '..')
455 462 scm = get_scm(repopath)[0]
456 463 repo = get_repo(path=repopath, alias=scm)
457 464 tip = repo.get_changeset()
458 465 return (tip.revision, tip.short_id)
459 466 except (ImportError, RepositoryError, VCSError), err:
460 467 logging.debug("Cannot retrieve rhodecode's revision. Original error "
461 468 "was: %s" % err)
462 469 return None
463 470
464 471 #===============================================================================
465 472 # TEST FUNCTIONS AND CREATORS
466 473 #===============================================================================
467 474 def create_test_index(repo_location, full_index):
468 475 """Makes default test index
469 476 :param repo_location:
470 477 :param full_index:
471 478 """
472 479 from rhodecode.lib.indexers.daemon import WhooshIndexingDaemon
473 480 from rhodecode.lib.pidlock import DaemonLock, LockHeld
474 481 import shutil
475 482
476 483 index_location = os.path.join(repo_location, 'index')
477 484 if os.path.exists(index_location):
478 485 shutil.rmtree(index_location)
479 486
480 487 try:
481 488 l = DaemonLock()
482 489 WhooshIndexingDaemon(index_location=index_location,
483 490 repo_location=repo_location)\
484 491 .run(full_index=full_index)
485 492 l.release()
486 493 except LockHeld:
487 494 pass
488 495
489 496 def create_test_env(repos_test_path, config):
490 497 """Makes a fresh database and
491 498 install test repository into tmp dir
492 499 """
493 500 from rhodecode.lib.db_manage import DbManage
494 501 from rhodecode.tests import HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, \
495 502 HG_FORK, GIT_FORK, TESTS_TMP_PATH
496 503 import tarfile
497 504 import shutil
498 505 from os.path import dirname as dn, join as jn, abspath
499 506
500 507 log = logging.getLogger('TestEnvCreator')
501 508 # create logger
502 509 log.setLevel(logging.DEBUG)
503 510 log.propagate = True
504 511 # create console handler and set level to debug
505 512 ch = logging.StreamHandler()
506 513 ch.setLevel(logging.DEBUG)
507 514
508 515 # create formatter
509 516 formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
510 517
511 518 # add formatter to ch
512 519 ch.setFormatter(formatter)
513 520
514 521 # add ch to logger
515 522 log.addHandler(ch)
516 523
517 524 #PART ONE create db
518 525 dbconf = config['sqlalchemy.db1.url']
519 526 log.debug('making test db %s', dbconf)
520 527
521 528 dbmanage = DbManage(log_sql=True, dbconf=dbconf, root=config['here'],
522 529 tests=True)
523 530 dbmanage.create_tables(override=True)
524 531 dbmanage.config_prompt(repos_test_path)
525 532 dbmanage.create_default_user()
526 533 dbmanage.admin_prompt()
527 534 dbmanage.create_permissions()
528 535 dbmanage.populate_default_permissions()
529 536
530 537 #PART TWO make test repo
531 538 log.debug('making test vcs repositories')
532 539
533 540 #remove old ones from previous tests
534 541 for r in [HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, HG_FORK, GIT_FORK]:
535 542
536 543 if os.path.isdir(jn(TESTS_TMP_PATH, r)):
537 544 log.debug('removing %s', r)
538 545 shutil.rmtree(jn(TESTS_TMP_PATH, r))
539 546
540 547 #CREATE DEFAULT HG REPOSITORY
541 548 cur_dir = dn(dn(abspath(__file__)))
542 549 tar = tarfile.open(jn(cur_dir, 'tests', "vcs_test_hg.tar.gz"))
543 550 tar.extractall(jn(TESTS_TMP_PATH, HG_REPO))
544 551 tar.close()
545 552
546 553
547 554 #==============================================================================
548 555 # PASTER COMMANDS
549 556 #==============================================================================
550 557
551 558 class BasePasterCommand(Command):
552 559 """
553 560 Abstract Base Class for paster commands.
554 561
555 562 The celery commands are somewhat aggressive about loading
556 563 celery.conf, and since our module sets the `CELERY_LOADER`
557 564 environment variable to our loader, we have to bootstrap a bit and
558 565 make sure we've had a chance to load the pylons config off of the
559 566 command line, otherwise everything fails.
560 567 """
561 568 min_args = 1
562 569 min_args_error = "Please provide a paster config file as an argument."
563 570 takes_config_file = 1
564 571 requires_config_file = True
565 572
566 573 def notify_msg(self, msg, log=False):
567 574 """Make a notification to user, additionally if logger is passed
568 575 it logs this action using given logger
569 576
570 577 :param msg: message that will be printed to user
571 578 :param log: logging instance, to use to additionally log this message
572 579
573 580 """
574 581 print msg
575 582 if log and isinstance(log, logging):
576 583 log(msg)
577 584
578 585
579 586 def run(self, args):
580 587 """
581 588 Overrides Command.run
582 589
583 590 Checks for a config file argument and loads it.
584 591 """
585 592 if len(args) < self.min_args:
586 593 raise BadCommand(
587 594 self.min_args_error % {'min_args': self.min_args,
588 595 'actual_args': len(args)})
589 596
590 597 # Decrement because we're going to lob off the first argument.
591 598 # @@ This is hacky
592 599 self.min_args -= 1
593 600 self.bootstrap_config(args[0])
594 601 self.update_parser()
595 602 return super(BasePasterCommand, self).run(args[1:])
596 603
597 604 def update_parser(self):
598 605 """
599 606 Abstract method. Allows for the class's parser to be updated
600 607 before the superclass's `run` method is called. Necessary to
601 608 allow options/arguments to be passed through to the underlying
602 609 celery command.
603 610 """
604 611 raise NotImplementedError("Abstract Method.")
605 612
606 613 def bootstrap_config(self, conf):
607 614 """
608 615 Loads the pylons configuration.
609 616 """
610 617 from pylons import config as pylonsconfig
611 618
612 619 path_to_ini_file = os.path.realpath(conf)
613 620 conf = paste.deploy.appconfig('config:' + path_to_ini_file)
614 621 pylonsconfig.init_app(conf.global_conf, conf.local_conf)
@@ -1,377 +1,377
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.model.scm
4 4 ~~~~~~~~~~~~~~~~~~~
5 5
6 6 Scm model for RhodeCode
7 7
8 8 :created_on: Apr 9, 2010
9 9 :author: marcink
10 10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software; you can redistribute it and/or
14 14 # modify it under the terms of the GNU General Public License
15 15 # as published by the Free Software Foundation; version 2
16 16 # of the License or (at your option) any later version of the license.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program; if not, write to the Free Software
25 25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 26 # MA 02110-1301, USA.
27 27 import os
28 28 import time
29 29 import traceback
30 30 import logging
31 31
32 32 from vcs import get_backend
33 33 from vcs.utils.helpers import get_scm
34 34 from vcs.exceptions import RepositoryError, VCSError
35 35 from vcs.utils.lazy import LazyProperty
36 36
37 37 from mercurial import ui
38 38
39 39 from beaker.cache import cache_region, region_invalidate
40 40
41 41 from rhodecode import BACKENDS
42 42 from rhodecode.lib import helpers as h
43 43 from rhodecode.lib.auth import HasRepoPermissionAny
44 from rhodecode.lib.utils import get_repos, make_ui, action_logger
44 from rhodecode.lib.utils import get_repos as get_filesystem_repos, make_ui, action_logger
45 45 from rhodecode.model import BaseModel
46 46 from rhodecode.model.user import UserModel
47 47
48 48 from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \
49 49 UserFollowing, UserLog
50 50 from rhodecode.model.caching_query import FromCache
51 51
52 52 from sqlalchemy.orm import joinedload
53 53 from sqlalchemy.orm.session import make_transient
54 54 from sqlalchemy.exc import DatabaseError
55 55
56 56 log = logging.getLogger(__name__)
57 57
58 58
59 59 class UserTemp(object):
60 60 def __init__(self, user_id):
61 61 self.user_id = user_id
62 62 class RepoTemp(object):
63 63 def __init__(self, repo_id):
64 64 self.repo_id = repo_id
65 65
66 66 class ScmModel(BaseModel):
67 67 """Generic Scm Model
68 68 """
69 69
70 70 @LazyProperty
71 71 def repos_path(self):
72 72 """Get's the repositories root path from database
73 73 """
74 74
75 75 q = self.sa.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key == '/').one()
76 76
77 77 return q.ui_value
78 78
79 79 def repo_scan(self, repos_path, baseui):
80 80 """Listing of repositories in given path. This path should not be a
81 81 repository itself. Return a dictionary of repository objects
82 82
83 83 :param repos_path: path to directory containing repositories
84 84 :param baseui: baseui instance to instantiate MercurialRepository with
85 85 """
86 86
87 87 log.info('scanning for repositories in %s', repos_path)
88 88
89 89 if not isinstance(baseui, ui.ui):
90 90 baseui = make_ui('db')
91 91 repos_list = {}
92 92
93 for name, path in get_repos(repos_path):
93 for name, path in get_filesystem_repos(repos_path, recursive=True):
94 94 try:
95 95 if repos_list.has_key(name):
96 96 raise RepositoryError('Duplicate repository name %s '
97 97 'found in %s' % (name, path))
98 98 else:
99 99
100 100 klass = get_backend(path[0])
101 101
102 102 if path[0] == 'hg' and path[0] in BACKENDS.keys():
103 103 repos_list[name] = klass(path[1], baseui=baseui)
104 104
105 105 if path[0] == 'git' and path[0] in BACKENDS.keys():
106 106 repos_list[name] = klass(path[1])
107 107 except OSError:
108 108 continue
109 109
110 110 return repos_list
111 111
112 112 def get_repos(self, all_repos=None):
113 113 """Get all repos from db and for each repo create it's backend instance.
114 114 and fill that backed with information from database
115 115
116 116 :param all_repos: give specific repositories list, good for filtering
117 117 """
118 118
119 119 if all_repos is None:
120 120 all_repos = self.sa.query(Repository)\
121 121 .order_by(Repository.repo_name).all()
122 122
123 123 #get the repositories that should be invalidated
124 124 invalidation_list = [str(x.cache_key) for x in \
125 125 self.sa.query(CacheInvalidation.cache_key)\
126 126 .filter(CacheInvalidation.cache_active == False)\
127 127 .all()]
128 128
129 129 for r in all_repos:
130 130
131 131 repo = self.get(r.repo_name, invalidation_list)
132 132
133 133 if repo is not None:
134 134 last_change = repo.last_change
135 135 tip = h.get_changeset_safe(repo, 'tip')
136 136
137 137 tmp_d = {}
138 138 tmp_d['name'] = repo.name
139 139 tmp_d['name_sort'] = tmp_d['name'].lower()
140 140 tmp_d['description'] = repo.dbrepo.description
141 141 tmp_d['description_sort'] = tmp_d['description']
142 142 tmp_d['last_change'] = last_change
143 143 tmp_d['last_change_sort'] = time.mktime(last_change.timetuple())
144 144 tmp_d['tip'] = tip.raw_id
145 145 tmp_d['tip_sort'] = tip.revision
146 146 tmp_d['rev'] = tip.revision
147 147 tmp_d['contact'] = repo.dbrepo.user.full_contact
148 148 tmp_d['contact_sort'] = tmp_d['contact']
149 149 tmp_d['repo_archives'] = list(repo._get_archives())
150 150 tmp_d['last_msg'] = tip.message
151 151 tmp_d['repo'] = repo
152 152 yield tmp_d
153 153
154 154 def get_repo(self, repo_name):
155 155 return self.get(repo_name)
156 156
157 157 def get(self, repo_name, invalidation_list=None):
158 158 """Get's repository from given name, creates BackendInstance and
159 159 propagates it's data from database with all additional information
160 160
161 161 :param repo_name:
162 162 :param invalidation_list: if a invalidation list is given the get
163 163 method should not manually check if this repository needs
164 164 invalidation and just invalidate the repositories in list
165 165
166 166 """
167 167 if not HasRepoPermissionAny('repository.read', 'repository.write',
168 168 'repository.admin')(repo_name, 'get repo check'):
169 169 return
170 170
171 171 #======================================================================
172 172 # CACHE FUNCTION
173 173 #======================================================================
174 174 @cache_region('long_term')
175 175 def _get_repo(repo_name):
176 176
177 177 repo_path = os.path.join(self.repos_path, repo_name)
178 178
179 179 try:
180 180 alias = get_scm(repo_path)[0]
181 181
182 182 log.debug('Creating instance of %s repository', alias)
183 183 backend = get_backend(alias)
184 184 except VCSError:
185 185 log.error(traceback.format_exc())
186 186 return
187 187
188 188 if alias == 'hg':
189 189 from pylons import app_globals as g
190 190 repo = backend(repo_path, create=False, baseui=g.baseui)
191 191 #skip hidden web repository
192 192 if repo._get_hidden():
193 193 return
194 194 else:
195 195 repo = backend(repo_path, create=False)
196 196
197 197 dbrepo = self.sa.query(Repository)\
198 198 .options(joinedload(Repository.fork))\
199 199 .options(joinedload(Repository.user))\
200 200 .filter(Repository.repo_name == repo_name)\
201 201 .scalar()
202 202
203 203 make_transient(dbrepo)
204 204 if dbrepo.user:
205 205 make_transient(dbrepo.user)
206 206 if dbrepo.fork:
207 207 make_transient(dbrepo.fork)
208 208
209 209 repo.dbrepo = dbrepo
210 210 return repo
211 211
212 212 pre_invalidate = True
213 213 if invalidation_list is not None:
214 214 pre_invalidate = repo_name in invalidation_list
215 215
216 216 if pre_invalidate:
217 217 invalidate = self._should_invalidate(repo_name)
218 218
219 219 if invalidate:
220 220 log.info('invalidating cache for repository %s', repo_name)
221 221 region_invalidate(_get_repo, None, repo_name)
222 222 self._mark_invalidated(invalidate)
223 223
224 224 return _get_repo(repo_name)
225 225
226 226
227 227
228 228 def mark_for_invalidation(self, repo_name):
229 229 """Puts cache invalidation task into db for
230 230 further global cache invalidation
231 231
232 232 :param repo_name: this repo that should invalidation take place
233 233 """
234 234
235 235 log.debug('marking %s for invalidation', repo_name)
236 236 cache = self.sa.query(CacheInvalidation)\
237 237 .filter(CacheInvalidation.cache_key == repo_name).scalar()
238 238
239 239 if cache:
240 240 #mark this cache as inactive
241 241 cache.cache_active = False
242 242 else:
243 243 log.debug('cache key not found in invalidation db -> creating one')
244 244 cache = CacheInvalidation(repo_name)
245 245
246 246 try:
247 247 self.sa.add(cache)
248 248 self.sa.commit()
249 249 except (DatabaseError,):
250 250 log.error(traceback.format_exc())
251 251 self.sa.rollback()
252 252
253 253
254 254 def toggle_following_repo(self, follow_repo_id, user_id):
255 255
256 256 f = self.sa.query(UserFollowing)\
257 257 .filter(UserFollowing.follows_repo_id == follow_repo_id)\
258 258 .filter(UserFollowing.user_id == user_id).scalar()
259 259
260 260 if f is not None:
261 261
262 262 try:
263 263 self.sa.delete(f)
264 264 self.sa.commit()
265 265 action_logger(UserTemp(user_id),
266 266 'stopped_following_repo',
267 267 RepoTemp(follow_repo_id))
268 268 return
269 269 except:
270 270 log.error(traceback.format_exc())
271 271 self.sa.rollback()
272 272 raise
273 273
274 274
275 275 try:
276 276 f = UserFollowing()
277 277 f.user_id = user_id
278 278 f.follows_repo_id = follow_repo_id
279 279 self.sa.add(f)
280 280 self.sa.commit()
281 281 action_logger(UserTemp(user_id),
282 282 'started_following_repo',
283 283 RepoTemp(follow_repo_id))
284 284 except:
285 285 log.error(traceback.format_exc())
286 286 self.sa.rollback()
287 287 raise
288 288
289 289 def toggle_following_user(self, follow_user_id , user_id):
290 290 f = self.sa.query(UserFollowing)\
291 291 .filter(UserFollowing.follows_user_id == follow_user_id)\
292 292 .filter(UserFollowing.user_id == user_id).scalar()
293 293
294 294 if f is not None:
295 295 try:
296 296 self.sa.delete(f)
297 297 self.sa.commit()
298 298 return
299 299 except:
300 300 log.error(traceback.format_exc())
301 301 self.sa.rollback()
302 302 raise
303 303
304 304 try:
305 305 f = UserFollowing()
306 306 f.user_id = user_id
307 307 f.follows_user_id = follow_user_id
308 308 self.sa.add(f)
309 309 self.sa.commit()
310 310 except:
311 311 log.error(traceback.format_exc())
312 312 self.sa.rollback()
313 313 raise
314 314
315 315 def is_following_repo(self, repo_name, user_id):
316 316 r = self.sa.query(Repository)\
317 317 .filter(Repository.repo_name == repo_name).scalar()
318 318
319 319 f = self.sa.query(UserFollowing)\
320 320 .filter(UserFollowing.follows_repository == r)\
321 321 .filter(UserFollowing.user_id == user_id).scalar()
322 322
323 323 return f is not None
324 324
325 325 def is_following_user(self, username, user_id):
326 326 u = UserModel(self.sa).get_by_username(username)
327 327
328 328 f = self.sa.query(UserFollowing)\
329 329 .filter(UserFollowing.follows_user == u)\
330 330 .filter(UserFollowing.user_id == user_id).scalar()
331 331
332 332 return f is not None
333 333
334 334 def get_followers(self, repo_id):
335 335 return self.sa.query(UserFollowing)\
336 336 .filter(UserFollowing.follows_repo_id == repo_id).count()
337 337
338 338 def get_forks(self, repo_id):
339 339 return self.sa.query(Repository)\
340 340 .filter(Repository.fork_id == repo_id).count()
341 341
342 342
343 343 def get_unread_journal(self):
344 344 return self.sa.query(UserLog).count()
345 345
346 346
347 347 def _should_invalidate(self, repo_name):
348 348 """Looks up database for invalidation signals for this repo_name
349 349
350 350 :param repo_name:
351 351 """
352 352
353 353 ret = self.sa.query(CacheInvalidation)\
354 354 .options(FromCache('sql_cache_short',
355 355 'get_invalidation_%s' % repo_name))\
356 356 .filter(CacheInvalidation.cache_key == repo_name)\
357 357 .filter(CacheInvalidation.cache_active == False)\
358 358 .scalar()
359 359
360 360 return ret
361 361
362 362 def _mark_invalidated(self, cache_key):
363 363 """ Marks all occurences of cache to invaldation as already invalidated
364 364
365 365 :param cache_key:
366 366 """
367 367
368 368 if cache_key:
369 369 log.debug('marking %s as already invalidated', cache_key)
370 370 try:
371 371 cache_key.cache_active = True
372 372 self.sa.add(cache_key)
373 373 self.sa.commit()
374 374 except (DatabaseError,):
375 375 log.error(traceback.format_exc())
376 376 self.sa.rollback()
377 377
General Comments 0
You need to be logged in to leave comments. Login now