##// END OF EJS Templates
tests: fixed tests for random dict order could sometimes break the tests
tests: fixed tests for random dict order could sometimes break the tests

File last commit:

r2043:338dc54d default
r2112:1550916c default
Show More
ssh_wrapper.py
606 lines | 20.5 KiB | text/x-python | PythonLexer
# -*- coding: utf-8 -*-
# Copyright (C) 2016-2017 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/
import os
import re
import sys
import json
import logging
import random
import signal
import tempfile
from subprocess import Popen, PIPE, check_output, CalledProcessError
import ConfigParser
import urllib2
import urlparse
import click
import pyramid.paster
# Module logger. Its records propagate to the root logger, which
# setup_logging() either configures from the .ini file (debug mode) or
# silences with a NullHandler so nothing leaks back to the SSH client.
log = logging.getLogger(__name__)
def setup_logging(ini_path, debug):
    """Configure logging for one SSH wrapper invocation.

    :param ini_path: path to the rhodecode .ini file; only used when
        ``debug`` is true.
    :param debug: when true, logging is configured from ``ini_path`` via
        ``pyramid.paster.setup_logging``; otherwise the root logger's
        handlers are replaced by a single ``NullHandler``.
    """
    if not debug:
        # Output from regularly configured handlers would be printed back
        # to the client running the SSH command, so silence everything.
        silent_handler = logging.NullHandler()
        logging.getLogger('').handlers = [silent_handler]
        return

    # Debug mode: let the rhodecode .ini file drive the logging setup.
    pyramid.paster.setup_logging(ini_path)
class SubversionTunnelWrapper(object):
    """Bridge an SSH session to an ``svnserve -t`` (tunnel mode) subprocess.

    The client's first protocol message is read from ``stdin`` and parsed so
    the repository URL can be inspected before it is forwarded to svnserve.
    On construction two temporary files are created: an svnserve config file
    whose ``hooks-env`` entry points at the second file, which holds the
    environment passed to repository hooks.
    """
    process = None

    def __init__(self, timeout, repositories_root=None, svn_path=None):
        """
        :param timeout: seconds to wait for the client's first message.
        :param repositories_root: passed to svnserve as ``-r`` when set.
        :param svn_path: svnserve executable, defaults to ``'svnserve'``.
        """
        self.timeout = timeout
        self.stdin = sys.stdin
        self.repositories_root = repositories_root
        self.svn_path = svn_path or 'svnserve'
        self.svn_conf_fd, self.svn_conf_path = tempfile.mkstemp()
        self.hooks_env_fd, self.hooks_env_path = tempfile.mkstemp()
        self.read_only = False
        self.create_svn_config()

    def create_svn_config(self):
        """Write the svnserve config file, pointing it at the hooks-env file."""
        content = (
            '[general]\n'
            'hooks-env = {}\n').format(self.hooks_env_path)
        with os.fdopen(self.svn_conf_fd, 'w') as config_file:
            config_file.write(content)

    def create_hooks_env(self):
        """Write the hooks-env file; adds SSH_READ_ONLY when read_only is set."""
        content = (
            '[default]\n'
            'LANG = en_US.UTF-8\n')
        if self.read_only:
            content += 'SSH_READ_ONLY = 1\n'

        with os.fdopen(self.hooks_env_fd, 'w') as hooks_env_file:
            hooks_env_file.write(content)

    def remove_configs(self):
        """Delete both temporary config files.

        Safe to call more than once (``fail()`` may run from a timeout and
        again from a later error path): files that are already gone are
        ignored; any other OS error is re-raised.
        """
        import errno
        for path in (self.svn_conf_path, self.hooks_env_path):
            try:
                os.remove(path)
            except OSError as exc:
                if exc.errno != errno.ENOENT:
                    raise

    def start(self):
        """Spawn ``svnserve -t`` with our config file; its stdin is a pipe."""
        config = ['--config-file', self.svn_conf_path]
        command = [self.svn_path, '-t'] + config
        if self.repositories_root:
            command.extend(['-r', self.repositories_root])
        self.process = Popen(command, stdin=PIPE)

    def sync(self):
        """Forward stdin to svnserve byte by byte until EOF or it exits."""
        while self.process.poll() is None:
            next_byte = self.stdin.read(1)
            if not next_byte:
                break
            self.process.stdin.write(next_byte)
        self.remove_configs()

    @property
    def return_code(self):
        return self.process.returncode

    def get_first_client_response(self):
        """Read and parse the client's first message under a SIGALRM timeout.

        :returns: dict of the parsed fields, or None if nothing was read or
            the message could not be parsed.
        """
        signal.signal(signal.SIGALRM, self.interrupt)
        signal.alarm(self.timeout)
        try:
            first_response = self._read_first_client_response()
        finally:
            # Always disarm the alarm, even if reading raised.
            signal.alarm(0)

        return (
            self._parse_first_client_response(first_response)
            if first_response else None)

    def patch_first_client_response(self, response, **kwargs):
        """Write the hooks env, then re-send the (optionally patched) first
        client message to svnserve. ``kwargs`` override parsed fields."""
        self.create_hooks_env()
        data = response.copy()
        data.update(kwargs)
        data['url'] = self._svn_string(data['url'])
        data['ra_client'] = self._svn_string(data['ra_client'])
        data['client'] = data['client'] or ''
        buffer_ = (
            "( {version} ( {capabilities} ) {url}{ra_client}"
            "( {client}) ) ".format(**data))
        self.process.stdin.write(buffer_)

    def fail(self, message):
        """Send ``message`` to the client as an svn failure and shut down.

        Prints a failure tuple (error code 210005) to stdout, removes the
        temporary config files and kills svnserve if it was started.
        """
        print(
            "( failure ( ( 210005 {message} 0: 0 ) ) )".format(
                message=self._svn_string(message)))
        self.remove_configs()
        if self.process is not None:
            try:
                self.process.kill()
            except OSError:
                # The process is already gone (e.g. fail() ran twice).
                pass

    def interrupt(self, signum, frame):
        """SIGALRM handler: the client did not send its first message in time."""
        self.fail("Exited by timeout")

    def _svn_string(self, str_):
        """Encode ``str_`` as an svn protocol string (``<len>:<data> ``)."""
        if not str_:
            return ''
        return '{length}:{string} '.format(length=len(str_), string=str_)

    def _read_first_client_response(self):
        """Read stdin up to the space after the outermost ``)``.

        Also stops on EOF (the client went away) or on an unbalanced ``)``.
        In those cases the partial buffer is returned and later fails to
        parse, instead of looping forever or raising IndexError.
        """
        buffer_ = ""
        brackets_stack = []
        while True:
            next_byte = self.stdin.read(1)
            if not next_byte:
                break
            buffer_ += next_byte
            if next_byte == "(":
                brackets_stack.append(next_byte)
            elif next_byte == ")":
                if not brackets_stack:
                    break
                brackets_stack.pop()
            elif next_byte == " " and not brackets_stack:
                break
        return buffer_

    def _parse_first_client_response(self, buffer_):
        """
        According to the Subversion RA protocol, the first request
        should look like:

        ( version:number ( cap:word ... ) url:string ? ra-client:string
           ( ? client:string ) )

        Please check https://svn.apache.org/repos/asf/subversion/trunk/
        subversion/libsvn_ra_svn/protocol

        :returns: dict with keys version, capabilities, url, ra_client and
            client (None when absent), or None if ``buffer_`` does not match.
        """
        version_re = r'(?P<version>\d+)'
        capabilities_re = r'\(\s(?P<capabilities>[\w\d\-\ ]+)\s\)'
        url_re = r'\d+\:(?P<url>[\W\w]+)'
        ra_client_re = r'(\d+\:(?P<ra_client>[\W\w]+)\s)'
        client_re = r'(\d+\:(?P<client>[\W\w]+)\s)*'
        regex = re.compile(
            r'^\(\s{version}\s{capabilities}\s{url}\s{ra_client}'
            r'\(\s{client}\)\s\)\s*$'.format(
                version=version_re, capabilities=capabilities_re,
                url=url_re, ra_client=ra_client_re, client=client_re))
        matcher = regex.match(buffer_)
        return matcher.groupdict() if matcher else None
class RhodeCodeApiClient(object):
    """Minimal JSON-RPC client for the RhodeCode ``/_admin/api`` endpoint."""

    def __init__(self, api_key, api_host):
        """
        :param api_key: API token sent with every request.
        :param api_host: base URL of the RhodeCode instance.
        :raises ValueError: if either value is empty.
        """
        self.api_key = api_key
        self.api_host = api_host

        if not api_key:
            raise ValueError('api_key:{} not defined'.format(api_key))
        if not api_host:
            raise ValueError('api_host:{} not defined '.format(api_host))

    def request(self, method, args):
        """Call API ``method`` with ``args``.

        :returns: ``(result, error)`` taken from the JSON response.
        :raises Exception: if the response id does not match the request id.
        """
        id_ = random.randrange(1, 9999)
        payload = {
            'id': id_,
            'api_key': self.api_key,
            'method': method,
            'args': args
        }
        host = '{host}/_admin/api'.format(host=self.api_host)

        log.debug('Doing API call to %s method:%s', host, method)
        req = urllib2.Request(
            host,
            data=json.dumps(payload),
            headers={'content-type': 'text/plain'})
        ret = urllib2.urlopen(req)
        try:
            raw_json = ret.read()
        finally:
            # Release the HTTP connection even if reading fails.
            ret.close()

        json_data = json.loads(raw_json)
        id_ret = json_data['id']

        if id_ret != id_:
            raise Exception('something went wrong. '
                            'ID mismatch got %s, expected %s | %s'
                            % (id_ret, id_, raw_json))

        result = json_data['result']
        error = json_data['error']
        return result, error

    def get_user_permissions(self, user, user_id):
        """Return the ``repositories`` permission mapping of user ``user_id``.

        :raises Exception: if the API returned no result.
        """
        result, error = self.request('get_user', {'userid': int(user_id)})
        # A missing result is fatal whether or not an error text came back;
        # continuing would only fail later with an AttributeError.
        if result is None:
            raise Exception(
                'User "%s" not found or another error happened: %s!' % (
                    user, error))
        log.debug(
            'Given User: `%s` Fetched User: `%s`', user, result.get('username'))
        return result.get('permissions').get('repositories')

    def invalidate_cache(self, repo_name):
        """Ask the server to invalidate caches of ``repo_name``."""
        log.debug('Invalidate cache for repo:%s', repo_name)
        return self.request('invalidate_cache', {'repoid': repo_name})

    def get_repo_store(self):
        """Return the repository store description from the API.

        :raises Exception: if the API returned no result.
        """
        result, error = self.request('get_repo_store', {})
        if result is None:
            raise Exception('Failed to fetch repository store: %s' % (error,))
        return result
class VcsServer(object):
    """Common state and helpers shared by the hg/git/svn SSH servers.

    Subclasses fill in ``store``, ``repo_name``, ``repo_mode`` and
    ``ini_path`` and implement :meth:`run`.
    """

    def __init__(self, user, user_permissions, config):
        self.user = user
        self.user_permissions = user_permissions
        self.config = config

        # Placeholders overwritten by the concrete server classes.
        self.repo_name = None
        self.repo_mode = None
        self.store = {}
        self.ini_path = ''

    def run(self):
        raise NotImplementedError()

    def get_root_store(self):
        """Return ``store['path']``, guaranteed to end with a ``/``."""
        store_path = self.store['path']
        if store_path.endswith('/'):
            return store_path
        return store_path + '/'
class MercurialServer(VcsServer):
    """Serve a Mercurial repository over SSH using ``hg serve --stdio``."""
    # Set to True on the instance by _check_permissions() for users that
    # only have read access.
    read_only = False

    def __init__(self, store, ini_path, repo_name,
                 user, user_permissions, config):
        super(MercurialServer, self).__init__(user, user_permissions, config)
        self.store = store
        self.repo_name = repo_name
        self.ini_path = ini_path
        self.hg_path = config.get('app:main', 'ssh.executable.hg')

    @staticmethod
    def _quote(value):
        """Shell-quote ``value``; plain paths come back unchanged."""
        try:
            from shlex import quote
        except ImportError:  # Python 2
            from pipes import quote
        return quote(value)

    def _repo_path(self, root):
        """Shell-safe ``<root><repo_name>`` path.

        repo_name comes from the client's SSH command, so it is quoted
        before being put into a shell command line.
        """
        return self._quote('{}{}'.format(root, self.repo_name))

    def run(self):
        """Serve the repository.

        :returns: ``(exit_code, changed)``. ``exit_code`` is 2 when
            permission is denied, otherwise the raw ``os.system`` status;
            ``changed`` is True when the tip moved during the session.
        """
        if not self._check_permissions():
            return 2, False

        tip_before = self.tip()
        exit_code = os.system(self.command)
        tip_after = self.tip()
        return exit_code, tip_before != tip_after

    def tip(self):
        """Return the tip node of the repository, or None if hg fails."""
        root = self.get_root_store()
        command = (
            'cd {root}; {hg_path} -R {repo_path} tip --template "{{node}}\n"'
            ''.format(
                root=root, hg_path=self.hg_path,
                repo_path=self._repo_path(root)))
        try:
            tip = check_output(command, shell=True).strip()
        except CalledProcessError:
            tip = None
        return tip

    @property
    def command(self):
        """Shell command running ``hg serve --stdio`` for the repository.

        For read-only users a failing ``pretxnchangegroup`` hook is
        configured so incoming changegroups are rejected.
        """
        root = self.get_root_store()
        arguments = (
            '--config hooks.pretxnchangegroup=\"false\"'
            if self.read_only else '')

        command = (
            "cd {root}; {hg_path} -R {repo_path} serve --stdio"
            " {arguments}".format(
                root=root, hg_path=self.hg_path,
                repo_path=self._repo_path(root),
                arguments=arguments))
        log.debug("Final CMD: %s", command)
        return command

    def _check_permissions(self):
        """Return False when the user has no access, True otherwise.

        Sets ``read_only`` for any permission other than admin/write.
        """
        permission = self.user_permissions.get(self.repo_name)
        if permission is None or permission == 'repository.none':
            log.error('repo not found or no permissions')
            return False

        elif permission in ['repository.admin', 'repository.write']:
            log.info(
                'Write Permissions for User "%s" granted to repo "%s"!',
                self.user, self.repo_name)
        else:
            self.read_only = True
            log.info(
                'Only Read Only access for User "%s" granted to repo "%s"!',
                self.user, self.repo_name)
        return True
class GitServer(VcsServer):
    """Serve a Git repository over SSH (``git-upload-pack``/``receive-pack``)."""

    def __init__(self, store, ini_path, repo_name, repo_mode,
                 user, user_permissions, config):
        super(GitServer, self).__init__(user, user_permissions, config)
        self.store = store
        self.ini_path = ini_path
        self.repo_name = repo_name
        self.repo_mode = repo_mode
        self.git_path = config.get('app:main', 'ssh.executable.git')

    def run(self):
        """Serve the repository.

        :returns: ``(exit_code, changed)``. ``exit_code`` is the non-zero
            permission code on denial, otherwise the raw ``os.system``
            status. ``changed`` is True for pushes (``receive-pack``).
        """
        exit_code = self._check_permissions()
        if exit_code:
            return exit_code, False

        self._update_environment()
        exit_code = os.system(self.command)
        return exit_code, self.repo_mode == "receive-pack"

    @property
    def command(self):
        """Shell command running ``git-<mode>`` on the repository path."""
        root = self.get_root_store()
        # The path is single-quoted for the shell. repo_name comes from the
        # client's SSH command, so any embedded single quote is closed,
        # escaped and reopened to keep it from breaking out of the quoting.
        repo_path = '{}{}'.format(root, self.repo_name).replace(
            "'", "'\"'\"'")
        command = "cd {root}; {git_path}-{mode} '{repo_path}'".format(
            root=root, git_path=self.git_path, mode=self.repo_mode,
            repo_path=repo_path)
        log.debug("Final CMD: %s", command)
        return command

    def _update_environment(self):
        """Export ``RC_SCM_DATA`` (JSON) into the environment of the git child."""
        # Must be a plain string: a trailing comma here previously turned it
        # into a one-element tuple, serialised as ["push"] / ["pull"].
        action = "push" if self.repo_mode == "receive-pack" else "pull"
        scm_data = {
            "ip": os.environ["SSH_CLIENT"].split()[0],
            "username": self.user,
            "action": action,
            "repository": self.repo_name,
            "scm": "git",
            "config": self.ini_path,
            "make_lock": None,
            "locked_by": [None, None]
        }
        os.putenv("RC_SCM_DATA", json.dumps(scm_data))

    def _check_permissions(self):
        """Return None when access is allowed, otherwise a non-zero code.

        2 means no repo or no permission, -3 means a push with read-only
        access, -2 means an unexpected permission value.
        """
        permission = self.user_permissions.get(self.repo_name)
        log.debug(
            'permission for %s on %s are: %s',
            self.user, self.repo_name, permission)

        if permission is None or permission == 'repository.none':
            log.error('repo not found or no permissions')
            return 2

        elif permission in ['repository.admin', 'repository.write']:
            log.info(
                'Write Permissions for User "%s" granted to repo "%s"!',
                self.user, self.repo_name)
        elif (permission == 'repository.read' and
                self.repo_mode == 'upload-pack'):
            log.info(
                'Only Read Only access for User "%s" granted to repo "%s"!',
                self.user, self.repo_name)
        elif (permission == 'repository.read'
                and self.repo_mode == 'receive-pack'):
            log.error(
                'Only Read Only access for User "%s" granted to repo "%s"!'
                ' Failing!', self.user, self.repo_name)
            return -3
        else:
            log.error('Cannot properly fetch user permission. '
                      'Return value is: %s', permission)
            return -2
class SubversionServer(VcsServer):
    """Serve a Subversion repository over SSH through ``svnserve -t``.

    Unlike hg/git, the repository name is not part of the SSH command; it
    is taken from the URL in the client's first protocol message.
    """

    def __init__(self, store, ini_path,
                 user, user_permissions, config):
        super(SubversionServer, self).__init__(user, user_permissions, config)
        self.store = store
        self.ini_path = ini_path
        # this is set in .run() from input stream
        self.repo_name = None
        self.svn_path = config.get('app:main', 'ssh.executable.svn')

    def run(self):
        """Run the svn tunnel.

        :returns: ``(exit_code, False)``. ``exit_code`` is 1 on failure,
            otherwise svnserve's return code.
        """
        root = self.get_root_store()
        log.debug("Using subversion binaries from '%s'", self.svn_path)

        self.tunnel = SubversionTunnelWrapper(
            timeout=self.timeout, repositories_root=root, svn_path=self.svn_path)
        self.tunnel.start()
        first_response = self.tunnel.get_first_client_response()
        if not first_response:
            self.tunnel.fail("Repository name cannot be extracted")
            return 1, False

        url_parts = urlparse.urlparse(first_response['url'])
        self.repo_name = url_parts.path.strip('/')
        if not self._check_permissions():
            # _check_permissions() has already reported the failure to the
            # client and torn the tunnel down. Calling fail() again would
            # print a second failure and re-delete removed config files.
            return 1, False

        self.tunnel.patch_first_client_response(first_response)
        self.tunnel.sync()
        return self.tunnel.return_code, False

    @property
    def timeout(self):
        """Seconds to wait for the client's first protocol message."""
        timeout = 30
        return timeout

    def _check_permissions(self):
        """Set the tunnel's read-only flag from the user's permission.

        Returns True when access is allowed. Otherwise fails the tunnel
        (message to the client plus cleanup) and returns False.
        """
        permission = self.user_permissions.get(self.repo_name)

        if permission in ['repository.admin', 'repository.write']:
            self.tunnel.read_only = False
            return True

        elif permission == 'repository.read':
            self.tunnel.read_only = True
            return True

        else:
            self.tunnel.fail("Not enough permissions for repository {}".format(
                self.repo_name))
            return False
class SshWrapper(object):
    """Dispatch an SSH command (``SSH_ORIGINAL_COMMAND``) to a VCS server."""

    def __init__(self, command, mode, user, user_id, shell, ini_path):
        self.command = command
        self.mode = mode
        self.user = user
        self.user_id = user_id
        self.shell = shell
        self.ini_path = ini_path

        self.config = self.parse_config(ini_path)
        api_key = self.config.get('app:main', 'ssh.api_key')
        api_host = self.config.get('app:main', 'ssh.api_host')
        self.api = RhodeCodeApiClient(api_key, api_host)

    def parse_config(self, config):
        """Load the .ini file at path ``config`` into a ConfigParser."""
        parser = ConfigParser.ConfigParser()
        parser.read(config)
        return parser

    def get_repo_details(self, mode):
        """Detect the VCS, repository name and mode from ``self.command``.

        :param mode: the CLI ``--mode``; used as the VCS type when it is one
            of svn/hg/git and the command does not match any pattern.
        :returns: ``(vcs_type, repo_name, mode)``. For git, ``mode`` becomes
            ``upload-pack``/``receive-pack``. For svn the name is None
            because it is read later from the input stream.
        """
        type_ = mode if mode in ['svn', 'hg', 'git'] else None
        name = None

        hg_pattern = r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$'
        hg_match = re.match(hg_pattern, self.command)
        if hg_match is not None:
            type_ = 'hg'
            name = hg_match.group(1).strip('/')
            return type_, name, mode

        git_pattern = (
            r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
        git_match = re.match(git_pattern, self.command)
        if git_match is not None:
            type_ = 'git'
            name = git_match.group(2).strip('/')
            mode = git_match.group(1)
            return type_, name, mode

        svn_pattern = r'^svnserve -t'
        svn_match = re.match(svn_pattern, self.command)
        if svn_match is not None:
            type_ = 'svn'
            # Repo name should be extracted from the input stream
            return type_, name, mode

        return type_, name, mode

    def serve(self, vcs, repo, mode, user, permissions):
        """Run the server for ``vcs``; return ``(exit_code, changed)``.

        :raises Exception: for an unrecognised ``vcs``.
        """
        store = self.api.get_repo_store()

        log.debug(
            'VCS detected:`%s` mode: `%s` repo: %s', vcs, mode, repo)

        if vcs == 'hg':
            server = MercurialServer(
                store=store, ini_path=self.ini_path,
                repo_name=repo, user=user,
                user_permissions=permissions, config=self.config)
            return server.run()

        elif vcs == 'git':
            server = GitServer(
                store=store, ini_path=self.ini_path,
                repo_name=repo, repo_mode=mode, user=user,
                user_permissions=permissions, config=self.config)
            return server.run()

        elif vcs == 'svn':
            server = SubversionServer(
                store=store, ini_path=self.ini_path,
                user=user,
                user_permissions=permissions, config=self.config)
            return server.run()

        else:
            raise Exception('Unrecognised VCS: {}'.format(vcs))

    def wrap(self):
        """Handle the SSH session and return the process exit code."""
        mode = self.mode
        user = self.user
        user_id = self.user_id
        shell = self.shell

        scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)
        log.debug(
            'Mode: `%s` User: `%s:%s` Shell: `%s` SSH Command: `\"%s\"` '
            'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
            mode, user, user_id, shell, self.command,
            scm_detected, scm_mode, scm_repo)

        try:
            permissions = self.api.get_user_permissions(user, user_id)
        except Exception:
            log.exception('Failed to fetch user permissions')
            return 1

        # main() passes '' (not None) when SSH_ORIGINAL_COMMAND is unset, so
        # test for emptiness rather than identity with None.
        if shell and not self.command:
            log.info(
                'Dropping to shell, no command given and shell is allowed')
            # execl(path, argv0, *args): request a login shell explicitly.
            os.execl('/bin/bash', 'bash', '-l')
            exit_code = 1

        elif scm_detected:
            try:
                exit_code, is_updated = self.serve(
                    scm_detected, scm_repo, scm_mode, user, permissions)
                if exit_code == 0 and is_updated:
                    self.api.invalidate_cache(scm_repo)
            except Exception:
                log.exception('Error occurred during execution of SshWrapper')
                exit_code = -1

        elif not self.command and not shell:
            log.error('No Command given.')
            exit_code = -1

        else:
            log.error(
                'Unhandled Command: "%s" Aborting.', self.command)
            exit_code = -1

        return exit_code
@click.command()
@click.argument('ini_path', type=click.Path(exists=True))
@click.option(
    '--mode', '-m', required=False, default='auto',
    type=click.Choice(['auto', 'vcs', 'git', 'hg', 'svn', 'test']),
    help='mode of operation')
@click.option('--user', help='Username for which the command will be executed')
@click.option('--user-id', help='User ID for which the command will be executed')
@click.option('--shell', '-s', is_flag=True, help='Allow Shell')
@click.option('--debug', is_flag=True, help='Enabled detailed output logging')
def main(ini_path, mode, user, user_id, shell, debug):
    """Run the RhodeCode SSH wrapper for SSH_ORIGINAL_COMMAND.

    INI_PATH is the RhodeCode .ini configuration file.
    """
    setup_logging(ini_path, debug)

    command = os.environ.get('SSH_ORIGINAL_COMMAND', '')
    # Only the 'test' mode may run without a command from sshd.
    if not command and mode not in ['test']:
        raise ValueError(
            'Unable to fetch SSH_ORIGINAL_COMMAND from environment. '
            'Please make sure this is set and available during execution '
            'of this script.')

    try:
        ssh_wrapper = SshWrapper(command, mode, user, user_id, shell, ini_path)
    except Exception:
        log.exception('Failed to execute SshWrapper')
        sys.exit(-5)

    sys.exit(ssh_wrapper.wrap())