__init__.py
262 lines
| 9.0 KiB
| text/x-python
|
PythonLexer
# Copyright (C) 2016-2023 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/
import os | ||||
import re | ||||
import logging | ||||
import datetime | ||||
r4927 | import configparser | |||
r4947 | from sqlalchemy import Table | |||
r2187 | ||||
r4947 | from rhodecode.lib.utils2 import AttributeDict | |||
r2187 | from rhodecode.model.scm import ScmModel | |||
from .hg import MercurialServer | ||||
from .git import GitServer | ||||
from .svn import SubversionServer | ||||
log = logging.getLogger(__name__)


class SshWrapper(object):
    """
    Entry point for an SSH session routed to RhodeCode.

    Parses the SSH command (``self.command``) to work out which VCS
    (Mercurial, Git or Subversion) and which repository the client asked
    for. It then hands the session to the matching server implementation,
    drops to a login shell, or rejects the request. :meth:`wrap` returns
    the process exit code.
    """

    # `hg -R <repo> serve --stdio`; group 1 is the repository path
    hg_cmd_pat = re.compile(r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$')
    # `git-upload-pack '<repo>[.git]'` or `git-receive-pack '...'`;
    # group 1 is the pack command, group 2 the repository path without `.git`
    git_cmd_pat = re.compile(r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
    # `svnserve -t`; the repository is only known from the input stream
    svn_cmd_pat = re.compile(r'^svnserve -t')

    def __init__(self, command, connection_info, mode,
                 user, user_id, key_id: int, shell, ini_path: str, env):
        """
        :param command: SSH command string sent by the client, or ``None``
            when no command was given.
        :param connection_info: whitespace separated connection description,
            see :meth:`get_connection_info`.
        :param mode: ``'hg'``, ``'git'`` or ``'svn'`` to force that VCS type;
            any other value lets the type be detected from ``command``.
        :param user: user name (stored as ``self.username``).
        :param user_id: database id of the user.
        :param key_id: database id of the SSH key used; falsy values skip the
            access-time update.
        :param shell: when truthy and ``command`` is ``None``, a login shell
            is started instead of a VCS server.
        :param ini_path: path of the configuration ``.ini`` file.
        :param env: environment passed on to the VCS server implementation.
        """
        self.command = command
        self.connection_info = connection_info
        self.mode = mode
        self.username = user
        self.user_id = user_id
        self.key_id = key_id
        self.shell = shell
        self.ini_path = ini_path
        self.env = env

        self.config = self.parse_config(ini_path)
        # set by serve() to the server object that handles the session
        self.server_impl = None

    def parse_config(self, config_path):
        """
        Read ``config_path`` into a :class:`configparser.ConfigParser`.

        ``ConfigParser.read`` silently skips files it cannot open, so a
        missing file yields an empty parser rather than an error.
        """
        parser = configparser.ConfigParser()
        parser.read(config_path)
        return parser

    def update_key_access_time(self, key_id):
        """
        Set ``accessed_on`` of SSH key ``key_id`` to the current time.

        The timestamp is a naive ``datetime.utcnow()`` value.
        """
        # lazy load db imports
        from rhodecode.model.meta import raw_query_executor, Base

        table = Table('user_ssh_keys', Base.metadata, autoload=False)
        atime = datetime.datetime.utcnow()
        # `.returning()` is deliberately not used: MySQL does not support it
        stmt = (
            table.update()
            .where(table.c.ssh_key_id == key_id)
            .values(accessed_on=atime)
        )

        with raw_query_executor() as session:
            updated_rows = session.execute(stmt).rowcount

        if updated_rows:
            log.debug('Update key id:`%s` access time', key_id)

    def get_user(self, user_id):
        """
        Load user ``user_id`` from the database.

        :returns: an ``AttributeDict`` with ``user_id``, ``username`` and
            ``auth_user`` (the result of ``User.AuthUser()``), or ``None``
            when no such user exists.
        """
        # lazy load db imports
        from rhodecode.model.db import User

        dbuser = User.get(user_id)
        if not dbuser:
            return None

        user = AttributeDict()
        user.user_id = dbuser.user_id
        user.username = dbuser.username
        user.auth_user = dbuser.AuthUser()
        return user

    def get_connection_info(self):
        """
        Parse ``self.connection_info`` into its four parts.

        The value (presumably taken from the ``SSH_CONNECTION`` environment
        variable) contains four whitespace-separated fields: client IP
        address, client port number, server IP address and server port
        number. A missing or malformed value yields ``None`` for every field.

        :returns: dict with keys ``client_ip``, ``client_port``,
            ``server_ip`` and ``server_port`` (strings or ``None``).
        """
        keys = ('client_ip', 'client_port', 'server_ip', 'server_port')
        # tolerate a missing value as well as stray/double whitespace
        info = (self.connection_info or '').split()
        if len(info) == len(keys):
            return dict(zip(keys, info))
        return dict.fromkeys(keys)

    def maybe_translate_repo_uid(self, repo_name):
        """
        Resolve a repository addressed by id (``_<ID>`` or ``_<ID>/<sub>``).

        Names starting with ``_`` are looked up with
        ``RepoModel().get_repo_by_id``. On a match, the real repository name
        is returned; otherwise the name is returned unchanged.

        :returns: tuple ``(repo_name, org_name)``. ``org_name`` is the
            requested name, cut at the first ``/`` when it starts with ``_``.
        """
        if not repo_name.startswith('_'):
            return repo_name, repo_name

        # remove format of _ID/subrepo
        org_name = repo_name.split('/', 1)[0]

        from rhodecode.model.repo import RepoModel
        requested = repo_name
        log.debug('translating UID repo %s', requested)
        by_id_match = RepoModel().get_repo_by_id(requested)
        if by_id_match:
            repo_name = by_id_match.repo_name
            log.debug('translation of UID repo %s got `%s`', requested, repo_name)
        return repo_name, org_name

    def get_repo_details(self, mode):
        """
        Detect the VCS type and repository from ``self.command``.

        :param mode: forced VCS type (``'hg'``, ``'git'``, ``'svn'``); it is
            used as the fallback when the command matches no known pattern.
        :returns: tuple ``(vcs_type, repo_name, mode)``. For git, ``mode`` is
            replaced by the pack command (``'upload-pack'`` or
            ``'receive-pack'``). ``repo_name`` is ``None`` for svn and for
            unrecognised commands.
        """
        vcs_type = mode if mode in ['svn', 'hg', 'git'] else None

        command = self.command
        if command is None:
            # no SSH command given: nothing to parse. wrap() handles this
            # case (shell or error), so it must not crash here.
            return vcs_type, None, mode

        hg_match = self.hg_cmd_pat.match(command)
        if hg_match is not None:
            repo_name, _ = self.maybe_translate_repo_uid(
                hg_match.group(1).strip('/'))
            return 'hg', repo_name, mode

        git_match = self.git_cmd_pat.match(command)
        if git_match is not None:
            repo_name, _ = self.maybe_translate_repo_uid(
                git_match.group(2).strip('/'))
            return 'git', repo_name, git_match.group(1)

        if self.svn_cmd_pat.match(command) is not None:
            # Repo name should be extracted from the input stream, we're
            # unable to extract it at this point in execution
            return 'svn', None, mode

        return vcs_type, None, mode

    def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
        """
        Create the server for ``vcs``, store it on ``self.server_impl`` and
        return the result of its ``run()``.

        :param vcs: ``'hg'``, ``'git'`` or ``'svn'``.
        :param repo: repository name (the svn server always gets
            ``repo_name=None``).
        :param mode: passed as ``repo_mode`` to the git server only.
        :param user: user object as returned by :meth:`get_user`.
        :param permissions: repository permissions of the user.
        :param branch_permissions: branch permissions for ``repo``. When
            truthy, branch permission checks and force-push detection are
            enabled through the tunnel extras.
        :raises Exception: for an unrecognised ``vcs``.
        """
        store = ScmModel().repos_path

        # branch rules exist -> check them and detect force pushes
        check_branch_perms = detect_force_push = bool(branch_permissions)

        log.debug(
            'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
            vcs, mode, repo, check_branch_perms)

        extras = {
            'detect_force_push': detect_force_push,
            'check_branch_perms': check_branch_perms,
        }
        common = dict(
            store=store, ini_path=self.ini_path, user=user,
            user_permissions=permissions, config=self.config, env=self.env)

        if vcs == 'hg':
            server = MercurialServer(repo_name=repo, **common)
        elif vcs == 'git':
            server = GitServer(repo_name=repo, repo_mode=mode, **common)
        elif vcs == 'svn':
            # the repository name is only known from the input stream
            server = SubversionServer(repo_name=None, **common)
        else:
            raise Exception(f'Unrecognised VCS: {vcs}')

        self.server_impl = server
        return server.run(tunnel_extras=extras)

    def wrap(self):
        """
        Handle the SSH session and return a process exit code.

        - no command and ``shell`` allowed: replace this process with a
          login shell (``os.execl`` does not return on success);
        - a detected VCS (from the command or a forced ``mode``): serve it
          and return the server's exit code, or ``-1`` on error;
        - anything else: log an error and return ``-1``.
        """
        mode = self.mode
        username = self.username
        user_id = self.user_id
        key_id = self.key_id
        shell = self.shell

        scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)

        log.debug(
            'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
            'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
            mode, username, user_id, shell, self.command,
            scm_detected, scm_mode, scm_repo)

        log.debug('SSH Connection info %s', self.get_connection_info())

        # update last access time for this key
        if key_id:
            self.update_key_access_time(key_id)

        if shell and self.command is None:
            log.info('Dropping to shell, no command given and shell is allowed')
            # '-l' is argv[0]; a leading '-' in argv[0] makes bash a login shell
            os.execl('/bin/bash', '-l')
            exit_code = 1

        elif scm_detected:
            user = self.get_user(user_id)
            if not user:
                log.warning('User with id %s not found', user_id)
                return -1

            auth_user = user.auth_user
            permissions = auth_user.permissions['repositories']
            repo_branch_permissions = auth_user.get_branch_permissions(scm_repo)
            try:
                exit_code, _is_updated = self.serve(
                    scm_detected, scm_repo, scm_mode, user, permissions,
                    repo_branch_permissions)
            except Exception:
                log.exception('Error occurred during execution of SshWrapper')
                exit_code = -1

        elif self.command is None:
            # no command and shell not allowed (any falsy `shell` value)
            log.error('No Command given.')
            exit_code = -1

        else:
            log.error('Unhandled Command: "%s" Aborting.', self.command)
            exit_code = -1

        return exit_code