Show More
utils.py
688 lines
| 21.8 KiB
| text/x-python
|
PythonLexer
r783 | # -*- coding: utf-8 -*- | |||
""" | ||||
r833 | rhodecode.lib.utils | |||
~~~~~~~~~~~~~~~~~~~ | ||||
r783 | ||||
Utilities library for RhodeCode | ||||
:created_on: Apr 18, 2010 | ||||
:author: marcink | ||||
r902 | :copyright: (C) 2009-2011 Marcin Kuzminski <marcin@python-works.com> | |||
r783 | :license: GPLv3, see COPYING for more details. | |||
""" | ||||
r547 | # This program is free software; you can redistribute it and/or | |||
# modify it under the terms of the GNU General Public License | ||||
# as published by the Free Software Foundation; version 2 | ||||
# of the License or (at your opinion) any later version of the license. | ||||
# | ||||
# This program is distributed in the hope that it will be useful, | ||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||
# GNU General Public License for more details. | ||||
# | ||||
# You should have received a copy of the GNU General Public License | ||||
# along with this program; if not, write to the Free Software | ||||
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, | ||||
# MA 02110-1301, USA. | ||||
r783 | ||||
import os | ||||
import logging | ||||
import datetime | ||||
import traceback | ||||
r1022 | import paste | |||
import beaker | ||||
from paste.script.command import Command, BadCommand | ||||
r633 | ||||
r631 | from UserDict import DictMixin | |||
r783 | ||||
r631 | from mercurial import ui, config, hg | |||
from mercurial.error import RepoError | ||||
r783 | ||||
r1022 | from webhelpers.text import collapse, remove_formatting, strip_tags | |||
r785 | ||||
r783 | from vcs.backends.base import BaseChangeset | |||
from vcs.utils.lazy import LazyProperty | ||||
r631 | from rhodecode.model import meta | |||
from rhodecode.model.caching_query import FromCache | ||||
r878 | from rhodecode.model.db import Repository, User, RhodeCodeUi, UserLog, Group | |||
r631 | from rhodecode.model.repo import RepoModel | |||
from rhodecode.model.user import UserModel | ||||
r756 | ||||
# Module-level logger named after this module's import path; used by the
# helper functions in this file for their diagnostics.
log = logging.getLogger(__name__)
def recursive_replace(str, replace=' '):
    """Collapse every run of ``replace`` in ``str`` into a single instance.

    Despite its name the function works iteratively: each pass shortens every
    run of ``replace``, and it stops once no doubled occurrence remains.

    :param str: given string (the name shadows the builtin but is kept so
        keyword callers keep working)
    :param replace: substring whose consecutive repeats are collapsed; an
        empty value leaves the string unchanged
    :returns: the string with no two adjacent occurrences of ``replace``

    Examples::
        >>> recursive_replace("Mighty---Mighty-Bo--sstones",'-')
        'Mighty-Mighty-Bo-sstones'
    """
    if not replace:
        # ''.find('') is 0, so the previous recursive version never reached
        # its base case for an empty separator (RecursionError).
        return str
    doubled = replace * 2
    while doubled in str:
        str = str.replace(doubled, replace)
    return str
def repo_name_slug(value):
    """Turn ``value`` into a slug usable as a repository name.

    Called whenever a repository is created or modified so that bad names do
    not end up in the repository store. Formatting and HTML tags are removed
    first. Every forbidden character is then replaced by ``-``. Finally, runs
    of ``-`` are collapsed and the result is passed through webhelpers'
    ``collapse`` with ``-``.

    :param value: raw repository name as typed by the user
    :returns: sanitised slug
    """
    forbidden_chars = """=[]\;'"<>,/~!@#$%^&*()+{}|: """
    cleaned = strip_tags(remove_formatting(value))
    for bad in forbidden_chars:
        cleaned = cleaned.replace(bad, '-')
    return collapse(recursive_replace(cleaned, '-'), '-')
def get_repo_slug(request):
    """Return the ``repo_name`` routing argument of ``request`` (or None).

    :param request: request whose ``environ`` holds the Pylons
        ``pylons.routes_dict`` mapping
    """
    routes = request.environ['pylons.routes_dict']
    return routes.get('repo_name')
def action_logger(user, action, repo, ipaddr='', sa=None):
    """
    Action logger for various actions made by users.

    Stores a ``UserLog`` row describing ``action``. The function is
    best-effort: any error is logged with its traceback and the session is
    rolled back instead of propagating to the caller.

    :param user: user that made this action, can be a unique username string or
        object containing user_id attribute
    :param action: action to log, should be on of predefined unique actions for
        easy translations
    :param repo: string name of repository or object containing repo_id,
        that action was made on
    :param ipaddr: optional ip address from what the action was made
    :param sa: optional sqlalchemy session; a new ``meta.Session()`` is used
        when not given
    """

    if not sa:
        sa = meta.Session()

    try:
        um = UserModel()
        if hasattr(user, 'user_id'):
            user_obj = user
        elif isinstance(user, basestring):
            user_obj = um.get_by_username(user, cache=False)
        else:
            raise Exception('You have to provide user object or username')
        if user_obj is None:
            # fail with a clear message instead of an AttributeError below
            raise Exception('User %s not found' % (user,))

        rm = RepoModel()
        if hasattr(repo, 'repo_id'):
            repo_obj = rm.get(repo.repo_id, cache=False)
            if repo_obj is None:
                raise Exception('Repository %s not found' % (repo,))
            repo_name = repo_obj.repo_name
        elif isinstance(repo, basestring):
            repo_name = repo.lstrip('/')
            repo_obj = rm.get_by_repo_name(repo_name, cache=False)
            if repo_obj is None:
                raise Exception('Repository %s not found' % (repo_name,))
        else:
            raise Exception('You have to provide repository to action logger')

        user_log = UserLog()
        user_log.user_id = user_obj.user_id
        user_log.action = action

        user_log.repository_id = repo_obj.repo_id
        user_log.repository_name = repo_name

        user_log.action_date = datetime.datetime.now()
        user_log.user_ip = ipaddr
        sa.add(user_log)
        sa.commit()

        log.info('Adding user %s, action %s on %s', user_obj, action, repo)
    except Exception:
        # ``except Exception`` rather than a bare ``except:`` so that
        # SystemExit / KeyboardInterrupt are not swallowed by the logger
        log.error(traceback.format_exc())
        sa.rollback()
r604 | ||||
def get_repos(path, recursive=False):
    """
    Scans given path for repos and yields (name, (type, path)) tuples.

    ``name`` is the repository directory relative to ``path`` (without a
    leading slash); the second element is whatever ``vcs`` ``get_scm``
    returned for that directory.

    :param path: path to scan for repositories
    :param recursive: when True, directories that are not repositories are
        searched for nested repositories, whose names keep the subdirectories
        in front
    :returns: generator of ``(name, scm_info)`` tuples
    """
    from vcs.utils.helpers import get_scm
    from vcs.exceptions import VCSError

    if len(path) > 1 and path.endswith('/'):
        # strip one trailing slash so names computed below are relative to
        # ``path``; the filesystem root '/' is kept as-is (stripping it would
        # leave '' and make os.listdir fail)
        path = path[:-1]

    def _get_repos(p):
        for dirpath in os.listdir(p):
            cur_path = os.path.join(p, dirpath)
            if os.path.isfile(cur_path):
                continue
            try:
                scm_info = get_scm(cur_path)
                # cur_path always starts with ``path``; slicing the prefix off
                # avoids split(path), which broke when the text of ``path``
                # appeared again deeper in the directory tree
                yield cur_path[len(path):].lstrip('/'), scm_info
            except VCSError:
                if not recursive:
                    continue
                # not a repository itself: look for repositories inside it
                if os.path.isdir(cur_path):
                    for inner_scm in _get_repos(cur_path):
                        yield inner_scm
    return _get_repos(path)
r631 | ||||
def check_repo_fast(repo_name, base_path):
    """
    Tell whether ``repo_name`` is still free under ``base_path``.

    Only looks for an existing directory; nothing is opened or verified.

    :param repo_name: repository name, relative to ``base_path``
    :param base_path: directory holding the repositories
    :return: False if this directory is present, True otherwise
    """
    return not os.path.isdir(os.path.join(base_path, repo_name))
def check_repo(repo_name, base_path, verify=True):
    """Return True when ``repo_name`` is free for creation under ``base_path``.

    Returns False when a directory of that name already exists. Otherwise the
    path is opened as a Mercurial repository (and verified when ``verify`` is
    set); a ``RepoError`` from that attempt means the name is free.

    NOTE(review): check_repo_fast only returns True when no directory exists
    at the path, where hg.repository presumably raises RepoError. If so, the
    verify/"already created" branch is effectively unreachable. Also, the
    result of hg.verify is ignored. Confirm before relying on it.

    :param repo_name: repository name relative to ``base_path``
    :param base_path: directory holding the repositories
    :param verify: run ``hg.verify`` on an opened repository
    """
    repo_path = os.path.join(base_path, repo_name)
    try:
        if check_repo_fast(repo_name, base_path):
            hg_repo = hg.repository(ui.ui(), repo_path)
            if verify:
                hg.verify(hg_repo)
            # at this point the repo exists and was (optionally) verified
            log.info('%s repo is already created', repo_name)
        return False
    except RepoError:
        # no valid repository at that location
        log.info('%s repo is free for creation', repo_name)
        return True
def ask_ok(prompt, retries=4, complaint='Yes or no, please!'):
    """Ask a yes/no question on the console until an accepted answer is given.

    :param prompt: text passed to ``raw_input``
    :param retries: number of unrecognised answers tolerated; the question is
        asked at most ``retries + 1`` times
    :param complaint: message printed after each unrecognised answer
    :returns: True for 'y'/'ye'/'yes', False for 'n'/'no'/'nop'/'nope'
    :raises IOError: when all retries are used up
    """
    yes_answers = ('y', 'ye', 'yes')
    no_answers = ('n', 'no', 'nop', 'nope')
    remaining = retries
    while True:
        answer = raw_input(prompt)
        if answer in yes_answers:
            return True
        if answer in no_answers:
            return False
        remaining -= 1
        if remaining < 0:
            raise IOError
        print(complaint)
r604 | ||||
# Config sections taken from the Mercurial hgrc documentation; when reading
# from a file, make_ui copies only the sections listed here.
ui_sections = [
    'alias',
    'auth',
    'decode/encode',
    'defaults',
    'diff',
    'email',
    'extensions',
    'format',
    'merge-patterns',
    'merge-tools',
    'hooks',
    'http_proxy',
    'smtp',
    'patch',
    'paths',
    'profiling',
    'server',
    'trusted',
    'ui',
    'web',
]
r604 | ||||
def make_ui(read_from='file', path=None, checkpaths=True):
    """Read python rc files or the database and build a mercurial ui object.

    :param read_from: 'file' to read the hgrc at ``path`` (only the sections
        listed in ``ui_sections`` are copied), or 'db' to load the active
        ``RhodeCodeUi`` rows; any other value returns a cleaned ui object
        with no settings applied
    :param path: path to mercurial config file (used with read_from='file')
    :param checkpaths: currently unused by this function
    :returns: the configured ``ui.ui`` instance, or False when the config
        file is missing or no path was given
    """
    baseui = ui.ui()

    # clean the baseui object; presumably this discards whatever ui.ui()
    # loaded from the system/user hgrc files (TODO confirm)
    baseui._ocfg = config.config()
    baseui._ucfg = config.config()
    baseui._tcfg = config.config()

    if read_from == 'file':
        # ``not path`` guards path=None, which made os.path.isfile raise
        if not path or not os.path.isfile(path):
            log.warning('Unable to read config file %s', path)
            return False
        log.debug('reading hgrc from %s', path)
        cfg = config.config()
        cfg.read(path)
        for section in ui_sections:
            for k, v in cfg.items(section):
                log.debug('settings ui from file[%s]%s:%s', section, k, v)
                baseui.setconfig(section, k, v)

    elif read_from == 'db':
        sa = meta.Session()
        hg_ui = sa.query(RhodeCodeUi)\
            .options(FromCache("sql_cache_short",
                               "get_hg_ui_settings")).all()

        for ui_ in hg_ui:
            if ui_.ui_active:
                log.debug('settings ui from db[%s]%s:%s', ui_.ui_section,
                          ui_.ui_key, ui_.ui_value)
                baseui.setconfig(ui_.ui_section, ui_.ui_key, ui_.ui_value)

        meta.Session.remove()
    return baseui
def set_rhodecode_config(config):
    """Copy the application settings stored in the database into ``config``.

    Existing keys in ``config`` are overwritten by the stored values.

    :param config: mutable mapping (the pylons config) receiving the settings
    """
    from rhodecode.model.settings import SettingsModel

    settings_model = SettingsModel()
    pairs = settings_model.get_app_settings().items()
    for name, setting in pairs:
        config[name] = setting
def invalidate_cache(cache_key, *args):
    """Puts cache invalidation task into db for further global cache
    invalidation.

    Only keys of the form ``get_repo_cached_<repo_name>`` are handled: the
    repository name is extracted and passed to
    ``ScmModel.mark_for_invalidation``. Other keys are ignored.

    :param cache_key: cache key to invalidate
    :param args: unused
    """
    from rhodecode.model.scm import ScmModel

    prefix = 'get_repo_cached_'
    if cache_key.startswith(prefix):
        # strip only the leading prefix: split(prefix)[-1] returned the wrong
        # name when the repository name itself contained the prefix text
        name = cache_key[len(prefix):]
        ScmModel().mark_for_invalidation(name)
r604 | ||||
class EmptyChangeset(BaseChangeset):
    """
    Placeholder changeset with revision -1, no message, author or date.

    A specific hash can be supplied; by default the id is forty zeros.
    """

    def __init__(self, cs='0' * 40):
        self._empty_cs = cs
        self.revision = -1
        self.message = self.author = self.date = ''

    @LazyProperty
    def raw_id(self):
        """Raw string identifying this changeset, useful for web
        representation (the hash given to the constructor)."""
        return self._empty_cs

    @LazyProperty
    def short_id(self):
        """First twelve characters of ``raw_id``."""
        return self.raw_id[:12]

    def get_file_changeset(self, path):
        """Every path resolves to this changeset itself."""
        return self

    def get_file_content(self, path):
        """Every path has empty (unicode) content."""
        return u''

    def get_file_size(self, path):
        """Every path has size zero."""
        return 0
r604 | ||||
def map_groups(groups):
    """Ensure the group hierarchy in ``groups`` exists and return its deepest
    group.

    Every element except the last is treated as a group name, ordered from
    the top level down. Missing groups are created (and committed) with the
    previously visited group as parent. The last element is not used.
    repo2db_mapper passes ``name.split('/')``, so it is presumably the
    repository name.

    NOTE(review): existing groups are looked up by name only, not by
    (name, parent). Same-named groups under different parents would collide,
    and ``.scalar()`` may raise if several rows match. Confirm intent.

    :param groups: list of group names, ending with a non-group element
    :returns: the last group in the structure, or None if ``groups`` has at
        most one element
    """
    session = meta.Session()
    current = None
    for name in groups[:-1]:
        existing = session.query(Group)\
            .filter(Group.group_name == name).scalar()
        if existing is None:
            existing = Group(name, current)
            session.add(existing)
            session.commit()
        current = existing
    return current
def repo2db_mapper(initial_repo_list, remove_obsolete=False):
    """Map all repos given in initial_repo_list into the database.

    Repositories that do not exist yet are created. If remove_obsolete is
    True, database entries not present in initial_repo_list are also removed.

    :param initial_repo_list: mapping of repository name -> repository object
        (found by scanning methods); objects must expose ``alias`` and
        ``description``
    :param remove_obsolete: check for obsolete entries in database
    :returns: tuple ``(added, removed)`` of lists of repository names
    """

    sa = meta.Session()
    rm = RepoModel()
    # new repositories are created on behalf of the first admin user found;
    # ``== True`` is intentional, it builds an SQL expression
    user = sa.query(User).filter(User.admin == True).first()

    added = []
    for name, repo in initial_repo_list.items():
        group = map_groups(name.split('/'))
        if not rm.get_by_repo_name(name, cache=False):
            log.info('repository %s not found creating default', name)
            added.append(name)
            if repo.description != 'unknown':
                description = repo.description
            else:
                description = '%s repository' % name
            form_data = {
                'repo_name': name,
                'repo_type': repo.alias,
                'description': description,
                'private': False,
                'group_id': getattr(group, 'group_id', None),
            }
            rm.create(form_data, user, just_db=True)

    removed = []
    if remove_obsolete:
        # remove from database those repositories that are not in the
        # filesystem; membership on the mapping itself is a hash lookup,
        # whereas ``in initial_repo_list.keys()`` built and scanned a new
        # list for every database row on Python 2
        for repo in sa.query(Repository).all():
            if repo.repo_name not in initial_repo_list:
                removed.append(repo.repo_name)
                sa.delete(repo)
                sa.commit()

    return added, removed
class OrderedDict(dict, DictMixin):
    """Dictionary that remembers key insertion order.

    Values are stored in the underlying ``dict``. Order is tracked separately
    by a circular doubly linked list of ``[key, prev, next]`` cells rooted
    at a sentinel cell (``self.__end``). ``self.__map`` maps each key to its
    cell, so unlinking a key on deletion needs no list scan.
    """

    def __init__(self, *args, **kwds):
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        try:
            # probe for existing linked-list state; calling __init__ again on
            # an initialised instance keeps its current contents
            self.__end
        except AttributeError:
            self.clear()
        self.update(*args, **kwds)

    def clear(self):
        self.__end = end = []
        end += [None, end, end]         # sentinel node for doubly linked list
        self.__map = {}                 # key --> [key, prev, next]
        dict.clear(self)

    def __setitem__(self, key, value):
        # only new keys are linked in (appended before the sentinel, i.e. at
        # the tail); re-assigning an existing key keeps its position
        if key not in self:
            end = self.__end
            curr = end[1]
            curr[2] = end[1] = self.__map[key] = [key, curr, end]
        dict.__setitem__(self, key, value)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        # unlink the key's cell from its neighbours
        key, prev, next = self.__map.pop(key)
        prev[2] = next
        next[1] = prev

    def __iter__(self):
        # walk forward from the sentinel: insertion order
        end = self.__end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        # walk backward from the sentinel: reverse insertion order
        end = self.__end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def popitem(self, last=True):
        """Remove and return a (key, value) pair: the newest one when
        ``last`` is true, otherwise the oldest one."""
        if not self:
            raise KeyError('dictionary is empty')
        if last:
            key = reversed(self).next()
        else:
            key = iter(self).next()
        value = self.pop(key)
        return key, value

    def __reduce__(self):
        # pickle support: items are stored as an ordered list. The private
        # linked-list attributes are hidden while copying vars(), so only
        # extra user-set instance attributes are included in the state.
        items = [[k, self[k]] for k in self]
        tmp = self.__map, self.__end
        del self.__map, self.__end
        inst_dict = vars(self).copy()
        self.__map, self.__end = tmp
        if inst_dict:
            return (self.__class__, (items,), inst_dict)
        return self.__class__, (items,)

    def keys(self):
        return list(self)

    # Per the UserDict.DictMixin documentation, these are derived from
    # keys()/__getitem__/__setitem__/__delitem__, so they go through the
    # ordered bookkeeping defined above instead of plain dict behaviour.
    setdefault = DictMixin.setdefault
    update = DictMixin.update
    pop = DictMixin.pop
    values = DictMixin.values
    items = DictMixin.items
    iterkeys = DictMixin.iterkeys
    itervalues = DictMixin.itervalues
    iteritems = DictMixin.iteritems

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, self.items())

    def copy(self):
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        d = cls()
        for key in iterable:
            d[key] = value
        return d

    def __eq__(self, other):
        # order-sensitive against another OrderedDict, plain dict equality
        # against anything else
        if isinstance(other, OrderedDict):
            return len(self) == len(other) and self.items() == other.items()
        return dict.__eq__(self, other)

    def __ne__(self, other):
        return not self == other
#set cache regions for beaker so celery can utilise it
def add_cache(settings):
    """Register beaker cache regions described in ``settings``.

    Keys starting with ``beaker.cache.`` or ``cache.`` are collected with that
    prefix removed. The comma separated ``regions`` value names the regions.
    Each region's options are read from ``<region>.<option>`` keys and stored
    in ``beaker.cache.cache_regions[region]``, with these defaults:
    ``expire`` (converted to int, default 60), ``lock_dir``/``data_dir``
    (from the global keys) and ``type`` (global value, else 'memory').

    :param settings: mapping of configuration keys to string values
    """
    # ``import beaker`` alone does not load the ``beaker.cache`` submodule
    import beaker.cache

    cache_settings = {'regions': None}
    for key in settings.keys():
        for prefix in ['beaker.cache.', 'cache.']:
            if key.startswith(prefix):
                # slice off only the leading prefix; split(prefix)[1] cut the
                # name short when the prefix text appeared again in the key
                name = key[len(prefix):].strip()
                cache_settings[name] = settings[key].strip()
    if cache_settings['regions']:
        for region in cache_settings['regions'].split(','):
            region = region.strip()
            if not region:
                # tolerate empty entries such as a trailing comma
                continue
            region_prefix = region + '.'
            region_settings = {}
            for key, value in cache_settings.items():
                # require '<region>.' so region 'short' does not pick up
                # 'short_term.*' options, and dot-less keys such as
                # 'lock_dir' cannot raise IndexError in the split below
                if key.startswith(region_prefix):
                    region_settings[key.split('.')[1]] = value
            region_settings['expire'] = int(region_settings.get('expire',
                                                                60))
            region_settings.setdefault('lock_dir',
                                       cache_settings.get('lock_dir'))
            region_settings.setdefault('data_dir',
                                       cache_settings.get('data_dir'))
            if 'type' not in region_settings:
                region_settings['type'] = cache_settings.get('type',
                                                             'memory')
            beaker.cache.cache_regions[region] = region_settings
def get_current_revision():
    """Returns tuple of (number, id) from repository containing this package
    or None if repository could not be found.

    The repository is looked up two directories above this file. A missing
    ``vcs`` library, or any ``RepositoryError``/``VCSError`` while reading
    it, is logged at debug level and results in None.
    """
    import sys

    logger = logging.getLogger(__name__)
    try:
        from vcs import get_repo
        from vcs.utils.helpers import get_scm
        from vcs.exceptions import RepositoryError, VCSError
    except ImportError:
        # Kept separate from the try below: when these imports failed there,
        # evaluating ``except (ImportError, RepositoryError, VCSError)``
        # referenced the never-bound names and raised UnboundLocalError.
        logger.debug("Cannot retrieve rhodecode's revision. Original error "
                     "was: %s", sys.exc_info()[1])
        return None

    try:
        repopath = os.path.join(os.path.dirname(__file__), '..', '..')
        scm = get_scm(repopath)[0]
        repo = get_repo(path=repopath, alias=scm)
        tip = repo.get_changeset()
        return (tip.revision, tip.short_id)
    except (RepositoryError, VCSError):
        logger.debug("Cannot retrieve rhodecode's revision. Original error "
                     "was: %s", sys.exc_info()[1])
        return None
r785 | ||||
#===============================================================================
# TEST FUNCTIONS AND CREATORS
#===============================================================================
def create_test_index(repo_location, full_index):
    """Makes default test index

    Any existing ``index`` directory inside ``repo_location`` is deleted. A
    new index is then built by ``WhooshIndexingDaemon`` while a
    ``DaemonLock`` is held. If ``LockHeld`` is raised, indexing is skipped.

    :param repo_location: directory containing the repositories; the index
        goes into its ``index`` subdirectory
    :param full_index: forwarded to ``WhooshIndexingDaemon.run``
    """
    from rhodecode.lib.indexers.daemon import WhooshIndexingDaemon
    from rhodecode.lib.pidlock import DaemonLock, LockHeld
    import shutil

    index_location = os.path.join(repo_location, 'index')
    if os.path.exists(index_location):
        shutil.rmtree(index_location)

    try:
        lock = DaemonLock()
        try:
            WhooshIndexingDaemon(index_location=index_location,
                                 repo_location=repo_location)\
                .run(full_index=full_index)
        finally:
            # release even when indexing raises, so a failed run does not
            # leave the lock held
            lock.release()
    except LockHeld:
        pass
def create_test_env(repos_test_path, config):
    """Makes a fresh database and installs the test repository into the tmp
    dir.

    :param repos_test_path: passed to ``DbManage.config_prompt`` when creating
        the settings
    :param config: mapping that must provide ``sqlalchemy.db1.url`` and
        ``here``
    """
    from rhodecode.lib.db_manage import DbManage
    from rhodecode.tests import HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, \
        HG_FORK, GIT_FORK, TESTS_TMP_PATH
    import tarfile
    import shutil
    from os.path import dirname as dn, join as jn, abspath

    log = logging.getLogger('TestEnvCreator')
    log.setLevel(logging.DEBUG)
    log.propagate = True
    # getLogger returns the same logger object on every call, so attach the
    # console handler only once; adding it per call duplicated every message
    if not log.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        ch.setFormatter(formatter)
        log.addHandler(ch)

    #PART ONE create db
    dbconf = config['sqlalchemy.db1.url']
    log.debug('making test db %s', dbconf)

    dbmanage = DbManage(log_sql=True, dbconf=dbconf, root=config['here'],
                        tests=True)
    dbmanage.create_tables(override=True)
    dbmanage.create_settings(dbmanage.config_prompt(repos_test_path))
    dbmanage.create_default_user()
    dbmanage.admin_prompt()
    dbmanage.create_permissions()
    dbmanage.populate_default_permissions()

    #PART TWO make test repo
    log.debug('making test vcs repositories')
    # remove leftovers from previous test runs
    for r in [HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, HG_FORK, GIT_FORK]:
        if os.path.isdir(jn(TESTS_TMP_PATH, r)):
            log.debug('removing %s', r)
            shutil.rmtree(jn(TESTS_TMP_PATH, r))

    #CREATE DEFAULT HG REPOSITORY
    cur_dir = dn(dn(abspath(__file__)))
    tar = tarfile.open(jn(cur_dir, 'tests', "vcs_test_hg.tar.gz"))
    try:
        tar.extractall(jn(TESTS_TMP_PATH, HG_REPO))
    finally:
        # close the archive even if extraction fails
        tar.close()
r684 | ||||
r785 | ||||
#==============================================================================
# PASTER COMMANDS
#==============================================================================
class BasePasterCommand(Command):
    """
    Abstract Base Class for paster commands.

    The celery commands are somewhat aggressive about loading
    celery.conf, and since our module sets the `CELERY_LOADER`
    environment variable to our loader, we have to bootstrap a bit and
    make sure we've had a chance to load the pylons config off of the
    command line, otherwise everything fails.
    """
    min_args = 1
    min_args_error = "Please provide a paster config file as an argument."
    takes_config_file = 1
    requires_config_file = True

    def notify_msg(self, msg, log=False):
        """Print a notification to the user; if a logger is passed, also log
        the message with it.

        :param msg: message that will be printed to user
        :param log: ``logging.Logger`` (message logged at INFO level) or any
            callable taking the message; falsy means print only
        """
        print(msg)
        if not log:
            return
        # The previous ``isinstance(log, logging)`` used the logging *module*
        # as a type and raised TypeError for every truthy ``log``.
        if isinstance(log, logging.Logger):
            log.info(msg)
        elif callable(log):
            log(msg)

    def run(self, args):
        """
        Overrides Command.run

        Checks for a config file argument and loads it.
        """
        if len(args) < self.min_args:
            raise BadCommand(
                self.min_args_error % {'min_args': self.min_args,
                                       'actual_args': len(args)})

        # Decrement because we're going to lob off the first argument.
        # @@ This is hacky
        self.min_args -= 1
        self.bootstrap_config(args[0])
        self.update_parser()
        return super(BasePasterCommand, self).run(args[1:])

    def update_parser(self):
        """
        Abstract method. Allows for the class's parser to be updated
        before the superclass's `run` method is called. Necessary to
        allow options/arguments to be passed through to the underlying
        celery command.
        """
        raise NotImplementedError("Abstract Method.")

    def bootstrap_config(self, conf):
        """
        Loads the pylons configuration from the ini file at ``conf``.
        """
        # ``import paste`` at module level does not load the ``paste.deploy``
        # subpackage, so import what is needed explicitly here
        from paste.deploy import appconfig
        from pylons import config as pylonsconfig

        path_to_ini_file = os.path.realpath(conf)
        conf = appconfig('config:' + path_to_ini_file)
        pylonsconfig.init_app(conf.global_conf, conf.local_conf)