##// END OF EJS Templates
feat(ssh-wrapper-speedup): major rewrite of code to address imports problem with ssh-wrapper-v2...
feat(ssh-wrapper-speedup): major rewrite of code to address imports problem with ssh-wrapper-v2 - use bootstrapped settings rather than config - use more code split to make sure we don't import heavy code

File last commit:

r5325:359b5cac default
r5325:359b5cac default
Show More
__init__.py
417 lines | 15.1 KiB | text/x-python | PythonLexer
# Copyright (C) 2016-2023 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/
import os
import re
import logging
import datetime
from sqlalchemy import Table
from rhodecode.lib.api_utils import call_service_api
from rhodecode.lib.utils2 import AttributeDict
from rhodecode.lib.vcs.exceptions import ImproperlyConfiguredError
from .hg import MercurialServer
from .git import GitServer
from .svn import SubversionServer
# Module-level logger used by the SSH wrapper classes defined in this module.
log = logging.getLogger(__name__)
class SshWrapper(object):
    """
    Route an incoming SSH command to the matching VCS server implementation.

    The wrapper inspects the SSH command string, detects which VCS
    (Mercurial, Git or Subversion) it targets, loads the calling user and
    permissions, and hands over to the respective server object.
    """
    hg_cmd_pat = re.compile(r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$')
    git_cmd_pat = re.compile(r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
    svn_cmd_pat = re.compile(r'^svnserve -t')

    def __init__(self, command, connection_info, mode,
                 user, user_id, key_id: int, shell, ini_path: str, settings, env):
        """
        Store the call parameters; no work is performed until `wrap()` runs.

        :param command: SSH command string to dispatch, may be None
        :param connection_info: space separated connection description,
            see `get_connection_info`
        :param mode: operation mode; 'svn', 'hg' or 'git' pre-select the VCS type
        :param user: username, stored as `self.username`
        :param user_id: id used to load the user in `get_user`
        :param key_id: id of the ssh key whose access time gets updated
        :param shell: when truthy and no command is given, drop to a shell
        :param ini_path: path of the .ini file handed to the server objects
        :param settings: settings object handed to the server objects
        :param env: environment handed to the server objects
        """
        self.command = command
        self.connection_info = connection_info
        self.mode = mode
        self.username = user
        self.user_id = user_id
        self.key_id = key_id
        self.shell = shell
        self.ini_path = ini_path
        self.env = env
        self.settings = settings
        # set by `serve()` to the server object that handled the request
        self.server_impl = None

    def update_key_access_time(self, key_id):
        """
        Set `accessed_on` of the `user_ssh_keys` row matching `key_id` to
        the current UTC time.
        """
        # lazy load db imports
        from rhodecode.model.meta import raw_query_executor, Base

        table = Table('user_ssh_keys', Base.metadata, autoload=False)
        atime = datetime.datetime.utcnow()
        stmt = (
            table.update()
            .where(table.c.ssh_key_id == key_id)
            .values(accessed_on=atime)
            # no MySQL Support for .returning :((
            #.returning(table.c.accessed_on, table.c.ssh_key_fingerprint)
        )

        with raw_query_executor() as session:
            result = session.execute(stmt)
            updated_rows = result.rowcount

        if updated_rows:
            log.debug('Update key id:`%s` access time', key_id)

    def get_user(self, user_id):
        """
        Load the user with `user_id` from the database.

        :return: AttributeDict with `user_id`, `username` and `auth_user`,
            or None when no such user exists
        """
        user = AttributeDict()
        # lazy load db imports
        from rhodecode.model.db import User

        dbuser = User.get(user_id)
        if not dbuser:
            return None

        user.user_id = dbuser.user_id
        user.username = dbuser.username
        user.auth_user = dbuser.AuthUser()
        return user

    def get_connection_info(self) -> dict:
        """
        connection_info

        Identifies the client and server ends of the connection.
        The variable contains four space-separated values: client IP address,
        client port number, server IP address, and server port number.

        :return: dict with `client_ip`, `client_port`, `server_ip` and
            `server_port`; all values are None unless exactly four values
            are present
        """
        conn = dict(
            client_ip=None,
            client_port=None,
            server_ip=None,
            server_port=None,
        )
        # guard against a missing value; this method is called for logging on
        # every request and must not break request handling
        info = (self.connection_info or '').split(' ')
        if len(info) == 4:
            conn.update(zip(conn, info))
        return conn

    def maybe_translate_repo_uid(self, repo_name):
        """
        Translate a `_ID` style repository reference into a repository name.

        Names not starting with `_` are returned unchanged.

        :return: tuple `(repo_name, org_name)`; `repo_name` is the translated
            name when a repository was found, `org_name` is the given name
            with anything after the first `/` removed for `_` prefixed names
        """
        org_name = repo_name
        if repo_name.startswith('_'):
            # remove format of _ID/subrepo
            org_name = repo_name.split('/', 1)[0]

            from rhodecode.model.repo import RepoModel
            log.debug('translating UID repo %s', repo_name)
            by_id_match = RepoModel().get_repo_by_id(repo_name)
            if by_id_match:
                uid_name = repo_name
                repo_name = by_id_match.repo_name
                log.debug('translation of UID repo %s got `%s`', uid_name, repo_name)

        return repo_name, org_name

    def get_repo_details(self, mode):
        """
        Detect VCS type, repository name and mode from the SSH command.

        :param mode: operation mode; used as the VCS type when it is one of
            'svn', 'hg', 'git' and the command itself matches no pattern
        :return: tuple `(vcs_type, repo_name, mode)`; `vcs_type` and
            `repo_name` may be None
        """
        vcs_type = mode if mode in ['svn', 'hg', 'git'] else None
        repo_name = None
        # `wrap()` explicitly handles a None command, matching a pattern
        # against None would raise TypeError before it gets the chance to
        command = self.command or ''

        hg_match = self.hg_cmd_pat.match(command)
        if hg_match is not None:
            repo_id = hg_match.group(1).strip('/')
            repo_name, _ = self.maybe_translate_repo_uid(repo_id)
            return 'hg', repo_name, mode

        git_match = self.git_cmd_pat.match(command)
        if git_match is not None:
            # for git the mode is the matched sub-command
            mode = git_match.group(1)
            repo_id = git_match.group(2).strip('/')
            repo_name, _ = self.maybe_translate_repo_uid(repo_id)
            return 'git', repo_name, mode

        if self.svn_cmd_pat.match(command) is not None:
            # Repo name should be extracted from the input stream, we're unable to
            # extract it at this point in execution
            return 'svn', repo_name, mode

        return vcs_type, repo_name, mode

    def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
        """
        Build the server object for `vcs` and run it.

        :param vcs: one of 'hg', 'git', 'svn'
        :param repo: repository name (not passed on for svn)
        :param mode: repo mode, passed to the git server only
        :param user: user object passed to the server
        :param permissions: repository permissions passed to the server
        :param branch_permissions: when truthy, branch permission checks and
            force push detection are enabled
        :return: result of the server's `run()`
        :raises Exception: when `vcs` is not recognised
        """
        # TODO: remove this once we have .ini defined access path...
        from rhodecode.model.scm import ScmModel
        store = ScmModel().repos_path

        # detect if we have to check branch permissions
        check_branch_perms = detect_force_push = bool(branch_permissions)

        log.debug(
            'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
            vcs, mode, repo, check_branch_perms)

        extras = {
            'detect_force_push': detect_force_push,
            'check_branch_perms': check_branch_perms,
            'config': self.ini_path
        }

        # arguments shared by all server implementations
        common = dict(
            store=store, ini_path=self.ini_path, user=user,
            user_permissions=permissions, settings=self.settings, env=self.env)

        if vcs == 'hg':
            server = MercurialServer(repo_name=repo, **common)
        elif vcs == 'git':
            server = GitServer(repo_name=repo, repo_mode=mode, **common)
        elif vcs == 'svn':
            server = SubversionServer(repo_name=None, **common)
        else:
            raise Exception(f'Unrecognised VCS: {vcs}')

        self.server_impl = server
        return server.run(tunnel_extras=extras)

    def wrap(self):
        """
        Entry point: handle the SSH command this wrapper was created for.

        :return: exit code; -1 on any handled error, otherwise the exit code
            reported by `serve()`
        """
        mode = self.mode
        username = self.username
        user_id = self.user_id
        key_id = self.key_id
        shell = self.shell

        scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)

        log.debug(
            'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
            'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
            mode, username, user_id, shell, self.command,
            scm_detected, scm_mode, scm_repo)

        log.debug('SSH Connection info %s', self.get_connection_info())

        # update last access time for this key
        if key_id:
            self.update_key_access_time(key_id)

        if shell and self.command is None:
            log.info('Dropping to shell, no command given and shell is allowed')
            # NOTE(review): os.execl passes '-l' as argv[0], not as an option;
            # bash treats a leading '-' in argv[0] as a login shell request, so
            # this presumably still yields a login shell - confirm intent.
            os.execl('/bin/bash', '-l')
            # only reached when the call above returns instead of replacing
            # the current process
            exit_code = 1

        elif scm_detected:
            user = self.get_user(user_id)
            if not user:
                log.warning('User with id %s not found', user_id)
                return -1

            auth_user = user.auth_user
            permissions = auth_user.permissions['repositories']
            repo_branch_permissions = auth_user.get_branch_permissions(scm_repo)
            try:
                # serve() result is unpacked as (exit_code, is_updated)
                exit_code, _ = self.serve(
                    scm_detected, scm_repo, scm_mode, user, permissions,
                    repo_branch_permissions)
            except Exception:
                log.exception('Error occurred during execution of SshWrapper')
                exit_code = -1

        elif self.command is None and shell is False:
            log.error('No Command given.')
            exit_code = -1

        else:
            log.error('Unhandled Command: "%s" Aborting.', self.command)
            exit_code = -1

        return exit_code
class SshWrapperStandalone(SshWrapper):
"""
New version of SshWrapper designed to be depended only on service API
"""
repos_path = None
service_api_host: str
service_api_token: str
api_url: str
def __init__(self, command, connection_info, mode,
user, user_id, key_id: int, shell, ini_path: str, settings, env):
# validate our settings for making a standalone calls
try:
self.service_api_host = settings['app.service_api.host']
self.service_api_token = settings['app.service_api.token']
except KeyError:
raise ImproperlyConfiguredError(
"app.service_api.host or app.service_api.token are missing. "
"Please ensure that app.service_api.host and app.service_api.token are "
"defined inside of .ini configuration file."
)
try:
self.api_url = settings['rhodecode.api.url']
except KeyError:
raise ImproperlyConfiguredError(
"rhodecode.api.url is missing. "
"Please ensure that rhodecode.api.url is "
"defined inside of .ini configuration file."
)
super(SshWrapperStandalone, self).__init__(
command, connection_info, mode, user, user_id, key_id, shell, ini_path, settings, env)
@staticmethod
def parse_user_related_data(user_data):
user = AttributeDict()
user.user_id = user_data['user_id']
user.username = user_data['username']
user.repo_permissions = user_data['repo_permissions']
user.branch_permissions = user_data['branch_permissions']
return user
def wrap(self):
mode = self.mode
username = self.username
user_id = self.user_id
shell = self.shell
scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)
log.debug(
'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
mode, username, user_id, shell, self.command,
scm_detected, scm_mode, scm_repo)
log.debug('SSH Connection info %s', self.get_connection_info())
if shell and self.command is None:
log.info('Dropping to shell, no command given and shell is allowed')
os.execl('/bin/bash', '-l')
exit_code = 1
elif scm_detected:
data = call_service_api(self.service_api_host, self.service_api_token, self.api_url, {
"method": "service_get_data_for_ssh_wrapper",
"args": {"user_id": user_id, "repo_name": scm_repo, "key_id": self.key_id}
})
user = self.parse_user_related_data(data)
if not user:
log.warning('User with id %s not found', user_id)
exit_code = -1
return exit_code
self.repos_path = data['repos_path']
permissions = user.repo_permissions
repo_branch_permissions = user.branch_permissions
try:
exit_code, is_updated = self.serve(
scm_detected, scm_repo, scm_mode, user, permissions,
repo_branch_permissions)
except Exception:
log.exception('Error occurred during execution of SshWrapper')
exit_code = -1
elif self.command is None and shell is False:
log.error('No Command given.')
exit_code = -1
else:
log.error('Unhandled Command: "%s" Aborting.', self.command)
exit_code = -1
return exit_code
def maybe_translate_repo_uid(self, repo_name):
_org_name = repo_name
if _org_name.startswith('_'):
_org_name = _org_name.split('/', 1)[0]
if repo_name.startswith('_'):
org_repo_name = repo_name
log.debug('translating UID repo %s', org_repo_name)
by_id_match = call_service_api(self.service_api_host, self.service_api_token, self.api_url, {
'method': 'service_get_repo_name_by_id',
"args": {"repo_id": repo_name}
})
if by_id_match:
repo_name = by_id_match['repo_name']
log.debug('translation of UID repo %s got `%s`', org_repo_name, repo_name)
return repo_name, _org_name
def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
store = self.repos_path
check_branch_perms = False
detect_force_push = False
if branch_permissions:
check_branch_perms = True
detect_force_push = True
log.debug(
'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
vcs, mode, repo, check_branch_perms)
# detect if we have to check branch permissions
extras = {
'detect_force_push': detect_force_push,
'check_branch_perms': check_branch_perms,
'config': self.ini_path
}
match vcs:
case 'hg':
server = MercurialServer(
store=store, ini_path=self.ini_path,
repo_name=repo, user=user,
user_permissions=permissions, settings=self.settings, env=self.env)
case 'git':
server = GitServer(
store=store, ini_path=self.ini_path,
repo_name=repo, repo_mode=mode, user=user,
user_permissions=permissions, settings=self.settings, env=self.env)
case 'svn':
server = SubversionServer(
store=store, ini_path=self.ini_path,
repo_name=None, user=user,
user_permissions=permissions, settings=self.settings, env=self.env)
case _:
raise Exception(f'Unrecognised VCS: {vcs}')
self.server_impl = server
return server.run(tunnel_extras=extras)