##// END OF EJS Templates
code: flake8 fixes
super-admin -
r1063:5823ab6c python3
parent child Browse files
Show More
@@ -1,49 +1,53 b''
1 from vcsserver.lib._vendor.statsd import client_from_config
1 from vcsserver.lib._vendor.statsd import client_from_config
2
2
3
3
4 class StatsdClientNotInitialised(Exception):
4 class StatsdClientNotInitialised(Exception):
5 pass
5 pass
6
6
7
7
8 class _Singleton(type):
8 class _Singleton(type):
9 """A metaclass that creates a Singleton base class when called."""
9 """A metaclass that creates a Singleton base class when called."""
10
10
11 _instances = {}
11 _instances = {}
12
12
13 def __call__(cls, *args, **kwargs):
13 def __call__(cls, *args, **kwargs):
14 if cls not in cls._instances:
14 if cls not in cls._instances:
15 cls._instances[cls] = super(_Singleton, cls).__call__(*args, **kwargs)
15 cls._instances[cls] = super(_Singleton, cls).__call__(*args, **kwargs)
16 return cls._instances[cls]
16 return cls._instances[cls]
17
17
18
18
19 class Singleton(_Singleton("SingletonMeta", (object,), {})):
19 class Singleton(_Singleton("SingletonMeta", (object,), {})):
20 pass
20 pass
21
21
22
22
23 class StatsdClientClass(Singleton):
23 class StatsdClientClass(Singleton):
24 setup_run = False
24 setup_run = False
25 statsd_client = None
25 statsd_client = None
26 statsd = None
26 statsd = None
27 strict_mode_init = False
27
28
28 def __getattribute__(self, name):
29 def __getattribute__(self, name):
29
30
30 if name.startswith("statsd"):
31 if name.startswith("statsd"):
31 if self.setup_run:
32 if self.setup_run:
32 return super(StatsdClientClass, self).__getattribute__(name)
33 return super(StatsdClientClass, self).__getattribute__(name)
33 else:
34 else:
35 if self.strict_mode_init:
36 raise StatsdClientNotInitialised(f"requested key was {name}")
34 return None
37 return None
35 #raise StatsdClientNotInitialised("requested key was %s" % name)
36
38
37 return super(StatsdClientClass, self).__getattribute__(name)
39 return super(StatsdClientClass, self).__getattribute__(name)
38
40
39 def setup(self, settings):
41 def setup(self, settings):
40 """
42 """
41 Initialize the client
43 Initialize the client
42 """
44 """
45 strict_init_mode = settings.pop('statsd_strict_init', False)
46
43 statsd = client_from_config(settings)
47 statsd = client_from_config(settings)
44 self.statsd = statsd
48 self.statsd = statsd
45 self.statsd_client = statsd
49 self.statsd_client = statsd
46 self.setup_run = True
50 self.setup_run = True
47
51
48
52
49 StatsdClient = StatsdClientClass()
53 StatsdClient = StatsdClientClass()
@@ -1,160 +1,160 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import os
18 import os
19 import tempfile
19 import tempfile
20
20
21 from svn import client
21 from svn import client
22 from svn import core
22 from svn import core
23 from svn import ra
23 from svn import ra
24
24
25 from mercurial import error
25 from mercurial import error
26
26
27 from vcsserver.str_utils import safe_bytes
27 from vcsserver.str_utils import safe_bytes
28
28
29 core.svn_config_ensure(None)
29 core.svn_config_ensure(None)
30 svn_config = core.svn_config_get_config(None)
30 svn_config = core.svn_config_get_config(None)
31
31
32
32
33 class RaCallbacks(ra.Callbacks):
33 class RaCallbacks(ra.Callbacks):
34 @staticmethod
34 @staticmethod
35 def open_tmp_file(pool): # pragma: no cover
35 def open_tmp_file(pool): # pragma: no cover
36 (fd, fn) = tempfile.mkstemp()
36 (fd, fn) = tempfile.mkstemp()
37 os.close(fd)
37 os.close(fd)
38 return fn
38 return fn
39
39
40 @staticmethod
40 @staticmethod
41 def get_client_string(pool):
41 def get_client_string(pool):
42 return b'RhodeCode-subversion-url-checker'
42 return b'RhodeCode-subversion-url-checker'
43
43
44
44
45 class SubversionException(Exception):
45 class SubversionException(Exception):
46 pass
46 pass
47
47
48
48
49 class SubversionConnectionException(SubversionException):
49 class SubversionConnectionException(SubversionException):
50 """Exception raised when a generic error occurs when connecting to a repository."""
50 """Exception raised when a generic error occurs when connecting to a repository."""
51
51
52
52
53 def normalize_url(url):
53 def normalize_url(url):
54 if not url:
54 if not url:
55 return url
55 return url
56 if url.startswith(b'svn+http://') or url.startswith(b'svn+https://'):
56 if url.startswith(b'svn+http://') or url.startswith(b'svn+https://'):
57 url = url[4:]
57 url = url[4:]
58 url = url.rstrip(b'/')
58 url = url.rstrip(b'/')
59 return url
59 return url
60
60
61
61
62 def _create_auth_baton(pool):
62 def _create_auth_baton(pool):
63 """Create a Subversion authentication baton. """
63 """Create a Subversion authentication baton. """
64 # Give the client context baton a suite of authentication
64 # Give the client context baton a suite of authentication
65 # providers.h
65 # providers.h
66 platform_specific = [
66 platform_specific = [
67 'svn_auth_get_gnome_keyring_simple_provider',
67 'svn_auth_get_gnome_keyring_simple_provider',
68 'svn_auth_get_gnome_keyring_ssl_client_cert_pw_provider',
68 'svn_auth_get_gnome_keyring_ssl_client_cert_pw_provider',
69 'svn_auth_get_keychain_simple_provider',
69 'svn_auth_get_keychain_simple_provider',
70 'svn_auth_get_keychain_ssl_client_cert_pw_provider',
70 'svn_auth_get_keychain_ssl_client_cert_pw_provider',
71 'svn_auth_get_kwallet_simple_provider',
71 'svn_auth_get_kwallet_simple_provider',
72 'svn_auth_get_kwallet_ssl_client_cert_pw_provider',
72 'svn_auth_get_kwallet_ssl_client_cert_pw_provider',
73 'svn_auth_get_ssl_client_cert_file_provider',
73 'svn_auth_get_ssl_client_cert_file_provider',
74 'svn_auth_get_windows_simple_provider',
74 'svn_auth_get_windows_simple_provider',
75 'svn_auth_get_windows_ssl_server_trust_provider',
75 'svn_auth_get_windows_ssl_server_trust_provider',
76 ]
76 ]
77
77
78 providers = []
78 providers = []
79
79
80 for p in platform_specific:
80 for p in platform_specific:
81 if getattr(core, p, None) is not None:
81 if getattr(core, p, None) is not None:
82 try:
82 try:
83 providers.append(getattr(core, p)())
83 providers.append(getattr(core, p)())
84 except RuntimeError:
84 except RuntimeError:
85 pass
85 pass
86
86
87 providers += [
87 providers += [
88 client.get_simple_provider(),
88 client.get_simple_provider(),
89 client.get_username_provider(),
89 client.get_username_provider(),
90 client.get_ssl_client_cert_file_provider(),
90 client.get_ssl_client_cert_file_provider(),
91 client.get_ssl_client_cert_pw_file_provider(),
91 client.get_ssl_client_cert_pw_file_provider(),
92 client.get_ssl_server_trust_file_provider(),
92 client.get_ssl_server_trust_file_provider(),
93 ]
93 ]
94
94
95 return core.svn_auth_open(providers, pool)
95 return core.svn_auth_open(providers, pool)
96
96
97
97
98 class SubversionRepo(object):
98 class SubversionRepo(object):
99 """Wrapper for a Subversion repository.
99 """Wrapper for a Subversion repository.
100
100
101 It uses the SWIG Python bindings, see above for requirements.
101 It uses the SWIG Python bindings, see above for requirements.
102 """
102 """
103 def __init__(self, svn_url: bytes = b'', username: bytes = b'', password: bytes = b''):
103 def __init__(self, svn_url: bytes = b'', username: bytes = b'', password: bytes = b''):
104
104
105 self.username = username
105 self.username = username
106 self.password = password
106 self.password = password
107 self.svn_url = core.svn_path_canonicalize(svn_url)
107 self.svn_url = core.svn_path_canonicalize(svn_url)
108
108
109 self.auth_baton_pool = core.Pool()
109 self.auth_baton_pool = core.Pool()
110 self.auth_baton = _create_auth_baton(self.auth_baton_pool)
110 self.auth_baton = _create_auth_baton(self.auth_baton_pool)
111 # self.init_ra_and_client() assumes that a pool already exists
111 # self.init_ra_and_client() assumes that a pool already exists
112 self.pool = core.Pool()
112 self.pool = core.Pool()
113
113
114 self.ra = self.init_ra_and_client()
114 self.ra = self.init_ra_and_client()
115 self.uuid = ra.get_uuid(self.ra, self.pool)
115 self.uuid = ra.get_uuid(self.ra, self.pool)
116
116
117 def init_ra_and_client(self):
117 def init_ra_and_client(self):
118 """Initializes the RA and client layers, because sometimes getting
118 """Initializes the RA and client layers, because sometimes getting
119 unified diffs runs the remote server out of open files.
119 unified diffs runs the remote server out of open files.
120 """
120 """
121
121
122 if self.username:
122 if self.username:
123 core.svn_auth_set_parameter(self.auth_baton,
123 core.svn_auth_set_parameter(self.auth_baton,
124 core.SVN_AUTH_PARAM_DEFAULT_USERNAME,
124 core.SVN_AUTH_PARAM_DEFAULT_USERNAME,
125 self.username)
125 self.username)
126 if self.password:
126 if self.password:
127 core.svn_auth_set_parameter(self.auth_baton,
127 core.svn_auth_set_parameter(self.auth_baton,
128 core.SVN_AUTH_PARAM_DEFAULT_PASSWORD,
128 core.SVN_AUTH_PARAM_DEFAULT_PASSWORD,
129 self.password)
129 self.password)
130
130
131 callbacks = RaCallbacks()
131 callbacks = RaCallbacks()
132 callbacks.auth_baton = self.auth_baton
132 callbacks.auth_baton = self.auth_baton
133
133
134 try:
134 try:
135 return ra.open2(self.svn_url, callbacks, svn_config, self.pool)
135 return ra.open2(self.svn_url, callbacks, svn_config, self.pool)
136 except SubversionException as e:
136 except SubversionException as e:
137 # e.child contains a detailed error messages
137 # e.child contains a detailed error messages
138 msglist = []
138 msglist = []
139 svn_exc = e
139 svn_exc = e
140 while svn_exc:
140 while svn_exc:
141 if svn_exc.args[0]:
141 if svn_exc.args[0]:
142 msglist.append(svn_exc.args[0])
142 msglist.append(svn_exc.args[0])
143 svn_exc = svn_exc.child
143 svn_exc = svn_exc.child
144 msg = '\n'.join(msglist)
144 msg = '\n'.join(msglist)
145 raise SubversionConnectionException(msg)
145 raise SubversionConnectionException(msg)
146
146
147
147
148 class svnremoterepo(object):
148 class svnremoterepo(object):
149 """ the dumb wrapper for actual Subversion repositories """
149 """ the dumb wrapper for actual Subversion repositories """
150
150
151 def __init__(self, username: bytes = b'', password: bytes = b'', svn_url: bytes = b''):
151 def __init__(self, username: bytes = b'', password: bytes = b'', svn_url: bytes = b''):
152 self.username = username or b''
152 self.username = username or b''
153 self.password = password or b''
153 self.password = password or b''
154 self.path = normalize_url(svn_url)
154 self.path = normalize_url(svn_url)
155
155
156 def svn(self):
156 def svn(self):
157 try:
157 try:
158 return SubversionRepo(self.path, self.username, self.password)
158 return SubversionRepo(self.path, self.username, self.password)
159 except SubversionConnectionException as e:
159 except SubversionConnectionException as e:
160 raise error.Abort(safe_bytes(e))
160 raise error.Abort(safe_bytes(e))
@@ -1,1317 +1,1317 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import collections
18 import collections
19 import logging
19 import logging
20 import os
20 import os
21 import posixpath as vcspath
21 import posixpath as vcspath
22 import re
22 import re
23 import stat
23 import stat
24 import traceback
24 import traceback
25 import urllib.request, urllib.parse, urllib.error
25 import urllib.request, urllib.parse, urllib.error
26 import urllib.request, urllib.error, urllib.parse
26 import urllib.request, urllib.error, urllib.parse
27 from functools import wraps
27 from functools import wraps
28
28
29 import more_itertools
29 import more_itertools
30 import pygit2
30 import pygit2
31 from pygit2 import Repository as LibGit2Repo
31 from pygit2 import Repository as LibGit2Repo
32 from pygit2 import index as LibGit2Index
32 from pygit2 import index as LibGit2Index
33 from dulwich import index, objects
33 from dulwich import index, objects
34 from dulwich.client import HttpGitClient, LocalGitClient
34 from dulwich.client import HttpGitClient, LocalGitClient
35 from dulwich.errors import (
35 from dulwich.errors import (
36 NotGitRepository, ChecksumMismatch, WrongObjectException,
36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 MissingCommitError, ObjectMissing, HangupException,
37 MissingCommitError, ObjectMissing, HangupException,
38 UnexpectedCommandError)
38 UnexpectedCommandError)
39 from dulwich.repo import Repo as DulwichRepo
39 from dulwich.repo import Repo as DulwichRepo
40 from dulwich.server import update_server_info
40 from dulwich.server import update_server_info
41
41
42 from vcsserver import exceptions, settings, subprocessio
42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.str_utils import safe_str, safe_int
43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes
44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo
44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo
45 from vcsserver.hgcompat import (
45 from vcsserver.hgcompat import (
46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 from vcsserver.git_lfs.lib import LFSOidStore
47 from vcsserver.git_lfs.lib import LFSOidStore
48 from vcsserver.vcs_base import RemoteBase
48 from vcsserver.vcs_base import RemoteBase
49
49
50 DIR_STAT = stat.S_IFDIR
50 DIR_STAT = stat.S_IFDIR
51 FILE_MODE = stat.S_IFMT
51 FILE_MODE = stat.S_IFMT
52 GIT_LINK = objects.S_IFGITLINK
52 GIT_LINK = objects.S_IFGITLINK
53 PEELED_REF_MARKER = '^{}'
53 PEELED_REF_MARKER = '^{}'
54
54
55
55
56 log = logging.getLogger(__name__)
56 log = logging.getLogger(__name__)
57
57
58
58
59 def reraise_safe_exceptions(func):
59 def reraise_safe_exceptions(func):
60 """Converts Dulwich exceptions to something neutral."""
60 """Converts Dulwich exceptions to something neutral."""
61
61
62 @wraps(func)
62 @wraps(func)
63 def wrapper(*args, **kwargs):
63 def wrapper(*args, **kwargs):
64 try:
64 try:
65 return func(*args, **kwargs)
65 return func(*args, **kwargs)
66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
67 exc = exceptions.LookupException(org_exc=e)
67 exc = exceptions.LookupException(org_exc=e)
68 raise exc(safe_str(e))
68 raise exc(safe_str(e))
69 except (HangupException, UnexpectedCommandError) as e:
69 except (HangupException, UnexpectedCommandError) as e:
70 exc = exceptions.VcsException(org_exc=e)
70 exc = exceptions.VcsException(org_exc=e)
71 raise exc(safe_str(e))
71 raise exc(safe_str(e))
72 except Exception as e:
72 except Exception as e:
73 # NOTE(marcink): becuase of how dulwich handles some exceptions
73 # NOTE(marcink): becuase of how dulwich handles some exceptions
74 # (KeyError on empty repos), we cannot track this and catch all
74 # (KeyError on empty repos), we cannot track this and catch all
75 # exceptions, it's an exceptions from other handlers
75 # exceptions, it's an exceptions from other handlers
76 #if not hasattr(e, '_vcs_kind'):
76 #if not hasattr(e, '_vcs_kind'):
77 #log.exception("Unhandled exception in git remote call")
77 #log.exception("Unhandled exception in git remote call")
78 #raise_from_original(exceptions.UnhandledException)
78 #raise_from_original(exceptions.UnhandledException)
79 raise
79 raise
80 return wrapper
80 return wrapper
81
81
82
82
83 class Repo(DulwichRepo):
83 class Repo(DulwichRepo):
84 """
84 """
85 A wrapper for dulwich Repo class.
85 A wrapper for dulwich Repo class.
86
86
87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
88 "Too many open files" error. We need to close all opened file descriptors
88 "Too many open files" error. We need to close all opened file descriptors
89 once the repo object is destroyed.
89 once the repo object is destroyed.
90 """
90 """
91 def __del__(self):
91 def __del__(self):
92 if hasattr(self, 'object_store'):
92 if hasattr(self, 'object_store'):
93 self.close()
93 self.close()
94
94
95
95
96 class Repository(LibGit2Repo):
96 class Repository(LibGit2Repo):
97
97
98 def __enter__(self):
98 def __enter__(self):
99 return self
99 return self
100
100
101 def __exit__(self, exc_type, exc_val, exc_tb):
101 def __exit__(self, exc_type, exc_val, exc_tb):
102 self.free()
102 self.free()
103
103
104
104
105 class GitFactory(RepoFactory):
105 class GitFactory(RepoFactory):
106 repo_type = 'git'
106 repo_type = 'git'
107
107
108 def _create_repo(self, wire, create, use_libgit2=False):
108 def _create_repo(self, wire, create, use_libgit2=False):
109 if use_libgit2:
109 if use_libgit2:
110 return Repository(wire['path'])
110 return Repository(wire['path'])
111 else:
111 else:
112 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
112 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
113 return Repo(repo_path)
113 return Repo(repo_path)
114
114
115 def repo(self, wire, create=False, use_libgit2=False):
115 def repo(self, wire, create=False, use_libgit2=False):
116 """
116 """
117 Get a repository instance for the given path.
117 Get a repository instance for the given path.
118 """
118 """
119 return self._create_repo(wire, create, use_libgit2)
119 return self._create_repo(wire, create, use_libgit2)
120
120
121 def repo_libgit2(self, wire):
121 def repo_libgit2(self, wire):
122 return self.repo(wire, use_libgit2=True)
122 return self.repo(wire, use_libgit2=True)
123
123
124
124
125 class GitRemote(RemoteBase):
125 class GitRemote(RemoteBase):
126
126
127 def __init__(self, factory):
127 def __init__(self, factory):
128 self._factory = factory
128 self._factory = factory
129 self._bulk_methods = {
129 self._bulk_methods = {
130 "date": self.date,
130 "date": self.date,
131 "author": self.author,
131 "author": self.author,
132 "branch": self.branch,
132 "branch": self.branch,
133 "message": self.message,
133 "message": self.message,
134 "parents": self.parents,
134 "parents": self.parents,
135 "_commit": self.revision,
135 "_commit": self.revision,
136 }
136 }
137
137
138 def _wire_to_config(self, wire):
138 def _wire_to_config(self, wire):
139 if 'config' in wire:
139 if 'config' in wire:
140 return dict([(x[0] + '_' + x[1], x[2]) for x in wire['config']])
140 return dict([(x[0] + '_' + x[1], x[2]) for x in wire['config']])
141 return {}
141 return {}
142
142
143 def _remote_conf(self, config):
143 def _remote_conf(self, config):
144 params = [
144 params = [
145 '-c', 'core.askpass=""',
145 '-c', 'core.askpass=""',
146 ]
146 ]
147 ssl_cert_dir = config.get('vcs_ssl_dir')
147 ssl_cert_dir = config.get('vcs_ssl_dir')
148 if ssl_cert_dir:
148 if ssl_cert_dir:
149 params.extend(['-c', 'http.sslCAinfo={}'.format(ssl_cert_dir)])
149 params.extend(['-c', 'http.sslCAinfo={}'.format(ssl_cert_dir)])
150 return params
150 return params
151
151
152 @reraise_safe_exceptions
152 @reraise_safe_exceptions
153 def discover_git_version(self):
153 def discover_git_version(self):
154 stdout, _ = self.run_git_command(
154 stdout, _ = self.run_git_command(
155 {}, ['--version'], _bare=True, _safe=True)
155 {}, ['--version'], _bare=True, _safe=True)
156 prefix = b'git version'
156 prefix = b'git version'
157 if stdout.startswith(prefix):
157 if stdout.startswith(prefix):
158 stdout = stdout[len(prefix):]
158 stdout = stdout[len(prefix):]
159 return stdout.strip()
159 return stdout.strip()
160
160
161 @reraise_safe_exceptions
161 @reraise_safe_exceptions
162 def is_empty(self, wire):
162 def is_empty(self, wire):
163 repo_init = self._factory.repo_libgit2(wire)
163 repo_init = self._factory.repo_libgit2(wire)
164 with repo_init as repo:
164 with repo_init as repo:
165
165
166 try:
166 try:
167 has_head = repo.head.name
167 has_head = repo.head.name
168 if has_head:
168 if has_head:
169 return False
169 return False
170
170
171 # NOTE(marcink): check again using more expensive method
171 # NOTE(marcink): check again using more expensive method
172 return repo.is_empty
172 return repo.is_empty
173 except Exception:
173 except Exception:
174 pass
174 pass
175
175
176 return True
176 return True
177
177
178 @reraise_safe_exceptions
178 @reraise_safe_exceptions
179 def assert_correct_path(self, wire):
179 def assert_correct_path(self, wire):
180 cache_on, context_uid, repo_id = self._cache_on(wire)
180 cache_on, context_uid, repo_id = self._cache_on(wire)
181 region = self._region(wire)
181 region = self._region(wire)
182
182
183 @region.conditional_cache_on_arguments(condition=cache_on)
183 @region.conditional_cache_on_arguments(condition=cache_on)
184 def _assert_correct_path(_context_uid, _repo_id):
184 def _assert_correct_path(_context_uid, _repo_id):
185 try:
185 try:
186 repo_init = self._factory.repo_libgit2(wire)
186 repo_init = self._factory.repo_libgit2(wire)
187 with repo_init as repo:
187 with repo_init as repo:
188 pass
188 pass
189 except pygit2.GitError:
189 except pygit2.GitError:
190 path = wire.get('path')
190 path = wire.get('path')
191 tb = traceback.format_exc()
191 tb = traceback.format_exc()
192 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
192 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
193 return False
193 return False
194
194
195 return True
195 return True
196 return _assert_correct_path(context_uid, repo_id)
196 return _assert_correct_path(context_uid, repo_id)
197
197
198 @reraise_safe_exceptions
198 @reraise_safe_exceptions
199 def bare(self, wire):
199 def bare(self, wire):
200 repo_init = self._factory.repo_libgit2(wire)
200 repo_init = self._factory.repo_libgit2(wire)
201 with repo_init as repo:
201 with repo_init as repo:
202 return repo.is_bare
202 return repo.is_bare
203
203
204 @reraise_safe_exceptions
204 @reraise_safe_exceptions
205 def blob_as_pretty_string(self, wire, sha):
205 def blob_as_pretty_string(self, wire, sha):
206 repo_init = self._factory.repo_libgit2(wire)
206 repo_init = self._factory.repo_libgit2(wire)
207 with repo_init as repo:
207 with repo_init as repo:
208 blob_obj = repo[sha]
208 blob_obj = repo[sha]
209 blob = blob_obj.data
209 blob = blob_obj.data
210 return blob
210 return blob
211
211
212 @reraise_safe_exceptions
212 @reraise_safe_exceptions
213 def blob_raw_length(self, wire, sha):
213 def blob_raw_length(self, wire, sha):
214 cache_on, context_uid, repo_id = self._cache_on(wire)
214 cache_on, context_uid, repo_id = self._cache_on(wire)
215 region = self._region(wire)
215 region = self._region(wire)
216
216
217 @region.conditional_cache_on_arguments(condition=cache_on)
217 @region.conditional_cache_on_arguments(condition=cache_on)
218 def _blob_raw_length(_repo_id, _sha):
218 def _blob_raw_length(_repo_id, _sha):
219
219
220 repo_init = self._factory.repo_libgit2(wire)
220 repo_init = self._factory.repo_libgit2(wire)
221 with repo_init as repo:
221 with repo_init as repo:
222 blob = repo[sha]
222 blob = repo[sha]
223 return blob.size
223 return blob.size
224
224
225 return _blob_raw_length(repo_id, sha)
225 return _blob_raw_length(repo_id, sha)
226
226
227 def _parse_lfs_pointer(self, raw_content):
227 def _parse_lfs_pointer(self, raw_content):
228 spec_string = b'version https://git-lfs.github.com/spec'
228 spec_string = b'version https://git-lfs.github.com/spec'
229 if raw_content and raw_content.startswith(spec_string):
229 if raw_content and raw_content.startswith(spec_string):
230
230
231 pattern = re.compile(rb"""
231 pattern = re.compile(rb"""
232 (?:\n)?
232 (?:\n)?
233 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
233 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
234 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
234 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
235 ^size[ ](?P<oid_size>[0-9]+)\n
235 ^size[ ](?P<oid_size>[0-9]+)\n
236 (?:\n)?
236 (?:\n)?
237 """, re.VERBOSE | re.MULTILINE)
237 """, re.VERBOSE | re.MULTILINE)
238 match = pattern.match(raw_content)
238 match = pattern.match(raw_content)
239 if match:
239 if match:
240 return match.groupdict()
240 return match.groupdict()
241
241
242 return {}
242 return {}
243
243
244 @reraise_safe_exceptions
244 @reraise_safe_exceptions
245 def is_large_file(self, wire, commit_id):
245 def is_large_file(self, wire, commit_id):
246 cache_on, context_uid, repo_id = self._cache_on(wire)
246 cache_on, context_uid, repo_id = self._cache_on(wire)
247 region = self._region(wire)
247 region = self._region(wire)
248
248
249 @region.conditional_cache_on_arguments(condition=cache_on)
249 @region.conditional_cache_on_arguments(condition=cache_on)
250 def _is_large_file(_repo_id, _sha):
250 def _is_large_file(_repo_id, _sha):
251 repo_init = self._factory.repo_libgit2(wire)
251 repo_init = self._factory.repo_libgit2(wire)
252 with repo_init as repo:
252 with repo_init as repo:
253 blob = repo[commit_id]
253 blob = repo[commit_id]
254 if blob.is_binary:
254 if blob.is_binary:
255 return {}
255 return {}
256
256
257 return self._parse_lfs_pointer(blob.data)
257 return self._parse_lfs_pointer(blob.data)
258
258
259 return _is_large_file(repo_id, commit_id)
259 return _is_large_file(repo_id, commit_id)
260
260
261 @reraise_safe_exceptions
261 @reraise_safe_exceptions
262 def is_binary(self, wire, tree_id):
262 def is_binary(self, wire, tree_id):
263 cache_on, context_uid, repo_id = self._cache_on(wire)
263 cache_on, context_uid, repo_id = self._cache_on(wire)
264 region = self._region(wire)
264 region = self._region(wire)
265
265
266 @region.conditional_cache_on_arguments(condition=cache_on)
266 @region.conditional_cache_on_arguments(condition=cache_on)
267 def _is_binary(_repo_id, _tree_id):
267 def _is_binary(_repo_id, _tree_id):
268 repo_init = self._factory.repo_libgit2(wire)
268 repo_init = self._factory.repo_libgit2(wire)
269 with repo_init as repo:
269 with repo_init as repo:
270 blob_obj = repo[tree_id]
270 blob_obj = repo[tree_id]
271 return blob_obj.is_binary
271 return blob_obj.is_binary
272
272
273 return _is_binary(repo_id, tree_id)
273 return _is_binary(repo_id, tree_id)
274
274
275 @reraise_safe_exceptions
275 @reraise_safe_exceptions
276 def in_largefiles_store(self, wire, oid):
276 def in_largefiles_store(self, wire, oid):
277 conf = self._wire_to_config(wire)
277 conf = self._wire_to_config(wire)
278 repo_init = self._factory.repo_libgit2(wire)
278 repo_init = self._factory.repo_libgit2(wire)
279 with repo_init as repo:
279 with repo_init as repo:
280 repo_name = repo.path
280 repo_name = repo.path
281
281
282 store_location = conf.get('vcs_git_lfs_store_location')
282 store_location = conf.get('vcs_git_lfs_store_location')
283 if store_location:
283 if store_location:
284
284
285 store = LFSOidStore(
285 store = LFSOidStore(
286 oid=oid, repo=repo_name, store_location=store_location)
286 oid=oid, repo=repo_name, store_location=store_location)
287 return store.has_oid()
287 return store.has_oid()
288
288
289 return False
289 return False
290
290
291 @reraise_safe_exceptions
291 @reraise_safe_exceptions
292 def store_path(self, wire, oid):
292 def store_path(self, wire, oid):
293 conf = self._wire_to_config(wire)
293 conf = self._wire_to_config(wire)
294 repo_init = self._factory.repo_libgit2(wire)
294 repo_init = self._factory.repo_libgit2(wire)
295 with repo_init as repo:
295 with repo_init as repo:
296 repo_name = repo.path
296 repo_name = repo.path
297
297
298 store_location = conf.get('vcs_git_lfs_store_location')
298 store_location = conf.get('vcs_git_lfs_store_location')
299 if store_location:
299 if store_location:
300 store = LFSOidStore(
300 store = LFSOidStore(
301 oid=oid, repo=repo_name, store_location=store_location)
301 oid=oid, repo=repo_name, store_location=store_location)
302 return store.oid_path
302 return store.oid_path
303 raise ValueError('Unable to fetch oid with path {}'.format(oid))
303 raise ValueError('Unable to fetch oid with path {}'.format(oid))
304
304
305 @reraise_safe_exceptions
305 @reraise_safe_exceptions
306 def bulk_request(self, wire, rev, pre_load):
306 def bulk_request(self, wire, rev, pre_load):
307 cache_on, context_uid, repo_id = self._cache_on(wire)
307 cache_on, context_uid, repo_id = self._cache_on(wire)
308 region = self._region(wire)
308 region = self._region(wire)
309
309
310 @region.conditional_cache_on_arguments(condition=cache_on)
310 @region.conditional_cache_on_arguments(condition=cache_on)
311 def _bulk_request(_repo_id, _rev, _pre_load):
311 def _bulk_request(_repo_id, _rev, _pre_load):
312 result = {}
312 result = {}
313 for attr in pre_load:
313 for attr in pre_load:
314 try:
314 try:
315 method = self._bulk_methods[attr]
315 method = self._bulk_methods[attr]
316 args = [wire, rev]
316 args = [wire, rev]
317 result[attr] = method(*args)
317 result[attr] = method(*args)
318 except KeyError as e:
318 except KeyError as e:
319 raise exceptions.VcsException(e)(
319 raise exceptions.VcsException(e)(
320 "Unknown bulk attribute: %s" % attr)
320 "Unknown bulk attribute: %s" % attr)
321 return result
321 return result
322
322
323 return _bulk_request(repo_id, rev, sorted(pre_load))
323 return _bulk_request(repo_id, rev, sorted(pre_load))
324
324
325 def _build_opener(self, url):
325 def _build_opener(self, url):
326 handlers = []
326 handlers = []
327 url_obj = url_parser(url)
327 url_obj = url_parser(url)
328 _, authinfo = url_obj.authinfo()
328 _, authinfo = url_obj.authinfo()
329
329
330 if authinfo:
330 if authinfo:
331 # create a password manager
331 # create a password manager
332 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
332 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
333 passmgr.add_password(*authinfo)
333 passmgr.add_password(*authinfo)
334
334
335 handlers.extend((httpbasicauthhandler(passmgr),
335 handlers.extend((httpbasicauthhandler(passmgr),
336 httpdigestauthhandler(passmgr)))
336 httpdigestauthhandler(passmgr)))
337
337
338 return urllib.request.build_opener(*handlers)
338 return urllib.request.build_opener(*handlers)
339
339
340 def _type_id_to_name(self, type_id: int):
340 def _type_id_to_name(self, type_id: int):
341 return {
341 return {
342 1: 'commit',
342 1: 'commit',
343 2: 'tree',
343 2: 'tree',
344 3: 'blob',
344 3: 'blob',
345 4: 'tag'
345 4: 'tag'
346 }[type_id]
346 }[type_id]
347
347
348 @reraise_safe_exceptions
348 @reraise_safe_exceptions
349 def check_url(self, url, config):
349 def check_url(self, url, config):
350 url_obj = url_parser(url)
350 url_obj = url_parser(url)
351 test_uri, _ = url_obj.authinfo()
351 test_uri, _ = url_obj.authinfo()
352 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
352 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
353 url_obj.query = obfuscate_qs(url_obj.query)
353 url_obj.query = obfuscate_qs(url_obj.query)
354 cleaned_uri = str(url_obj)
354 cleaned_uri = str(url_obj)
355 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
355 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
356
356
357 if not test_uri.endswith('info/refs'):
357 if not test_uri.endswith('info/refs'):
358 test_uri = test_uri.rstrip('/') + '/info/refs'
358 test_uri = test_uri.rstrip('/') + '/info/refs'
359
359
360 o = self._build_opener(url)
360 o = self._build_opener(url)
361 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
361 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
362
362
363 q = {"service": 'git-upload-pack'}
363 q = {"service": 'git-upload-pack'}
364 qs = '?%s' % urllib.parse.urlencode(q)
364 qs = '?%s' % urllib.parse.urlencode(q)
365 cu = "%s%s" % (test_uri, qs)
365 cu = "%s%s" % (test_uri, qs)
366 req = urllib.request.Request(cu, None, {})
366 req = urllib.request.Request(cu, None, {})
367
367
368 try:
368 try:
369 log.debug("Trying to open URL %s", cleaned_uri)
369 log.debug("Trying to open URL %s", cleaned_uri)
370 resp = o.open(req)
370 resp = o.open(req)
371 if resp.code != 200:
371 if resp.code != 200:
372 raise exceptions.URLError()('Return Code is not 200')
372 raise exceptions.URLError()('Return Code is not 200')
373 except Exception as e:
373 except Exception as e:
374 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
374 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
375 # means it cannot be cloned
375 # means it cannot be cloned
376 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
376 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
377
377
378 # now detect if it's proper git repo
378 # now detect if it's proper git repo
379 gitdata = resp.read()
379 gitdata = resp.read()
380 if 'service=git-upload-pack' in gitdata:
380 if 'service=git-upload-pack' in gitdata:
381 pass
381 pass
382 elif re.findall(r'[0-9a-fA-F]{40}\s+refs', gitdata):
382 elif re.findall(r'[0-9a-fA-F]{40}\s+refs', gitdata):
383 # old style git can return some other format !
383 # old style git can return some other format !
384 pass
384 pass
385 else:
385 else:
386 raise exceptions.URLError()(
386 raise exceptions.URLError()(
387 "url [%s] does not look like an git" % (cleaned_uri,))
387 "url [%s] does not look like an git" % (cleaned_uri,))
388
388
389 return True
389 return True
390
390
391 @reraise_safe_exceptions
391 @reraise_safe_exceptions
392 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
392 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
393 # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
393 # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
394 remote_refs = self.pull(wire, url, apply_refs=False)
394 remote_refs = self.pull(wire, url, apply_refs=False)
395 repo = self._factory.repo(wire)
395 repo = self._factory.repo(wire)
396 if isinstance(valid_refs, list):
396 if isinstance(valid_refs, list):
397 valid_refs = tuple(valid_refs)
397 valid_refs = tuple(valid_refs)
398
398
399 for k in remote_refs:
399 for k in remote_refs:
400 # only parse heads/tags and skip so called deferred tags
400 # only parse heads/tags and skip so called deferred tags
401 if k.startswith(valid_refs) and not k.endswith(deferred):
401 if k.startswith(valid_refs) and not k.endswith(deferred):
402 repo[k] = remote_refs[k]
402 repo[k] = remote_refs[k]
403
403
404 if update_after_clone:
404 if update_after_clone:
405 # we want to checkout HEAD
405 # we want to checkout HEAD
406 repo["HEAD"] = remote_refs["HEAD"]
406 repo["HEAD"] = remote_refs["HEAD"]
407 index.build_index_from_tree(repo.path, repo.index_path(),
407 index.build_index_from_tree(repo.path, repo.index_path(),
408 repo.object_store, repo["HEAD"].tree)
408 repo.object_store, repo["HEAD"].tree)
409
409
410 @reraise_safe_exceptions
410 @reraise_safe_exceptions
411 def branch(self, wire, commit_id):
411 def branch(self, wire, commit_id):
412 cache_on, context_uid, repo_id = self._cache_on(wire)
412 cache_on, context_uid, repo_id = self._cache_on(wire)
413 region = self._region(wire)
413 region = self._region(wire)
414 @region.conditional_cache_on_arguments(condition=cache_on)
414 @region.conditional_cache_on_arguments(condition=cache_on)
415 def _branch(_context_uid, _repo_id, _commit_id):
415 def _branch(_context_uid, _repo_id, _commit_id):
416 regex = re.compile('^refs/heads')
416 regex = re.compile('^refs/heads')
417
417
418 def filter_with(ref):
418 def filter_with(ref):
419 return regex.match(ref[0]) and ref[1] == _commit_id
419 return regex.match(ref[0]) and ref[1] == _commit_id
420
420
421 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
421 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
422 return [x[0].split('refs/heads/')[-1] for x in branches]
422 return [x[0].split('refs/heads/')[-1] for x in branches]
423
423
424 return _branch(context_uid, repo_id, commit_id)
424 return _branch(context_uid, repo_id, commit_id)
425
425
426 @reraise_safe_exceptions
426 @reraise_safe_exceptions
427 def commit_branches(self, wire, commit_id):
427 def commit_branches(self, wire, commit_id):
428 cache_on, context_uid, repo_id = self._cache_on(wire)
428 cache_on, context_uid, repo_id = self._cache_on(wire)
429 region = self._region(wire)
429 region = self._region(wire)
430 @region.conditional_cache_on_arguments(condition=cache_on)
430 @region.conditional_cache_on_arguments(condition=cache_on)
431 def _commit_branches(_context_uid, _repo_id, _commit_id):
431 def _commit_branches(_context_uid, _repo_id, _commit_id):
432 repo_init = self._factory.repo_libgit2(wire)
432 repo_init = self._factory.repo_libgit2(wire)
433 with repo_init as repo:
433 with repo_init as repo:
434 branches = [x for x in repo.branches.with_commit(_commit_id)]
434 branches = [x for x in repo.branches.with_commit(_commit_id)]
435 return branches
435 return branches
436
436
437 return _commit_branches(context_uid, repo_id, commit_id)
437 return _commit_branches(context_uid, repo_id, commit_id)
438
438
439 @reraise_safe_exceptions
439 @reraise_safe_exceptions
440 def add_object(self, wire, content):
440 def add_object(self, wire, content):
441 repo_init = self._factory.repo_libgit2(wire)
441 repo_init = self._factory.repo_libgit2(wire)
442 with repo_init as repo:
442 with repo_init as repo:
443 blob = objects.Blob()
443 blob = objects.Blob()
444 blob.set_raw_string(content)
444 blob.set_raw_string(content)
445 repo.object_store.add_object(blob)
445 repo.object_store.add_object(blob)
446 return blob.id
446 return blob.id
447
447
448 # TODO: this is quite complex, check if that can be simplified
448 # TODO: this is quite complex, check if that can be simplified
449 @reraise_safe_exceptions
449 @reraise_safe_exceptions
450 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
450 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
451 # Defines the root tree
451 # Defines the root tree
452 class _Root(object):
452 class _Root(object):
453 def __repr__(self):
453 def __repr__(self):
454 return 'ROOT TREE'
454 return 'ROOT TREE'
455 ROOT = _Root()
455 ROOT = _Root()
456
456
457 repo = self._factory.repo(wire)
457 repo = self._factory.repo(wire)
458 object_store = repo.object_store
458 object_store = repo.object_store
459
459
460 # Create tree and populates it with blobs
460 # Create tree and populates it with blobs
461
461
462 if commit_tree and repo[commit_tree]:
462 if commit_tree and repo[commit_tree]:
463 git_commit = repo[commit_data['parents'][0]]
463 git_commit = repo[commit_data['parents'][0]]
464 commit_tree = repo[git_commit.tree] # root tree
464 commit_tree = repo[git_commit.tree] # root tree
465 else:
465 else:
466 commit_tree = objects.Tree()
466 commit_tree = objects.Tree()
467
467
468 for node in updated:
468 for node in updated:
469 # Compute subdirs if needed
469 # Compute subdirs if needed
470 dirpath, nodename = vcspath.split(node['path'])
470 dirpath, nodename = vcspath.split(node['path'])
471 dirnames = list(map(safe_str, dirpath and dirpath.split('/') or []))
471 dirnames = list(map(safe_str, dirpath and dirpath.split('/') or []))
472 parent = commit_tree
472 parent = commit_tree
473 ancestors = [('', parent)]
473 ancestors = [('', parent)]
474
474
475 # Tries to dig for the deepest existing tree
475 # Tries to dig for the deepest existing tree
476 while dirnames:
476 while dirnames:
477 curdir = dirnames.pop(0)
477 curdir = dirnames.pop(0)
478 try:
478 try:
479 dir_id = parent[curdir][1]
479 dir_id = parent[curdir][1]
480 except KeyError:
480 except KeyError:
481 # put curdir back into dirnames and stops
481 # put curdir back into dirnames and stops
482 dirnames.insert(0, curdir)
482 dirnames.insert(0, curdir)
483 break
483 break
484 else:
484 else:
485 # If found, updates parent
485 # If found, updates parent
486 parent = repo[dir_id]
486 parent = repo[dir_id]
487 ancestors.append((curdir, parent))
487 ancestors.append((curdir, parent))
488 # Now parent is deepest existing tree and we need to create
488 # Now parent is deepest existing tree and we need to create
489 # subtrees for dirnames (in reverse order)
489 # subtrees for dirnames (in reverse order)
490 # [this only applies for nodes from added]
490 # [this only applies for nodes from added]
491 new_trees = []
491 new_trees = []
492
492
493 blob = objects.Blob.from_string(node['content'])
493 blob = objects.Blob.from_string(node['content'])
494
494
495 if dirnames:
495 if dirnames:
496 # If there are trees which should be created we need to build
496 # If there are trees which should be created we need to build
497 # them now (in reverse order)
497 # them now (in reverse order)
498 reversed_dirnames = list(reversed(dirnames))
498 reversed_dirnames = list(reversed(dirnames))
499 curtree = objects.Tree()
499 curtree = objects.Tree()
500 curtree[node['node_path']] = node['mode'], blob.id
500 curtree[node['node_path']] = node['mode'], blob.id
501 new_trees.append(curtree)
501 new_trees.append(curtree)
502 for dirname in reversed_dirnames[:-1]:
502 for dirname in reversed_dirnames[:-1]:
503 newtree = objects.Tree()
503 newtree = objects.Tree()
504 newtree[dirname] = (DIR_STAT, curtree.id)
504 newtree[dirname] = (DIR_STAT, curtree.id)
505 new_trees.append(newtree)
505 new_trees.append(newtree)
506 curtree = newtree
506 curtree = newtree
507 parent[reversed_dirnames[-1]] = (DIR_STAT, curtree.id)
507 parent[reversed_dirnames[-1]] = (DIR_STAT, curtree.id)
508 else:
508 else:
509 parent.add(name=node['node_path'], mode=node['mode'], hexsha=blob.id)
509 parent.add(name=node['node_path'], mode=node['mode'], hexsha=blob.id)
510
510
511 new_trees.append(parent)
511 new_trees.append(parent)
512 # Update ancestors
512 # Update ancestors
513 reversed_ancestors = reversed(
513 reversed_ancestors = reversed(
514 [(a[1], b[1], b[0]) for a, b in zip(ancestors, ancestors[1:])])
514 [(a[1], b[1], b[0]) for a, b in zip(ancestors, ancestors[1:])])
515 for parent, tree, path in reversed_ancestors:
515 for parent, tree, path in reversed_ancestors:
516 parent[path] = (DIR_STAT, tree.id)
516 parent[path] = (DIR_STAT, tree.id)
517 object_store.add_object(tree)
517 object_store.add_object(tree)
518
518
519 object_store.add_object(blob)
519 object_store.add_object(blob)
520 for tree in new_trees:
520 for tree in new_trees:
521 object_store.add_object(tree)
521 object_store.add_object(tree)
522
522
523 for node_path in removed:
523 for node_path in removed:
524 paths = node_path.split('/')
524 paths = node_path.split('/')
525 tree = commit_tree # start with top-level
525 tree = commit_tree # start with top-level
526 trees = [{'tree': tree, 'path': ROOT}]
526 trees = [{'tree': tree, 'path': ROOT}]
527 # Traverse deep into the forest...
527 # Traverse deep into the forest...
528 # resolve final tree by iterating the path.
528 # resolve final tree by iterating the path.
529 # e.g a/b/c.txt will get
529 # e.g a/b/c.txt will get
530 # - root as tree then
530 # - root as tree then
531 # - 'a' as tree,
531 # - 'a' as tree,
532 # - 'b' as tree,
532 # - 'b' as tree,
533 # - stop at c as blob.
533 # - stop at c as blob.
534 for path in paths:
534 for path in paths:
535 try:
535 try:
536 obj = repo[tree[path][1]]
536 obj = repo[tree[path][1]]
537 if isinstance(obj, objects.Tree):
537 if isinstance(obj, objects.Tree):
538 trees.append({'tree': obj, 'path': path})
538 trees.append({'tree': obj, 'path': path})
539 tree = obj
539 tree = obj
540 except KeyError:
540 except KeyError:
541 break
541 break
542 #PROBLEM:
542 #PROBLEM:
543 """
543 """
544 We're not editing same reference tree object
544 We're not editing same reference tree object
545 """
545 """
546 # Cut down the blob and all rotten trees on the way back...
546 # Cut down the blob and all rotten trees on the way back...
547 for path, tree_data in reversed(list(zip(paths, trees))):
547 for path, tree_data in reversed(list(zip(paths, trees))):
548 tree = tree_data['tree']
548 tree = tree_data['tree']
549 tree.__delitem__(path)
549 tree.__delitem__(path)
550 # This operation edits the tree, we need to mark new commit back
550 # This operation edits the tree, we need to mark new commit back
551
551
552 if len(tree) > 0:
552 if len(tree) > 0:
553 # This tree still has elements - don't remove it or any
553 # This tree still has elements - don't remove it or any
554 # of it's parents
554 # of it's parents
555 break
555 break
556
556
557 object_store.add_object(commit_tree)
557 object_store.add_object(commit_tree)
558
558
559 # Create commit
559 # Create commit
560 commit = objects.Commit()
560 commit = objects.Commit()
561 commit.tree = commit_tree.id
561 commit.tree = commit_tree.id
562 for k, v in commit_data.items():
562 for k, v in commit_data.items():
563 setattr(commit, k, v)
563 setattr(commit, k, v)
564 object_store.add_object(commit)
564 object_store.add_object(commit)
565
565
566 self.create_branch(wire, branch, commit.id)
566 self.create_branch(wire, branch, commit.id)
567
567
568 # dulwich set-ref
568 # dulwich set-ref
569 ref = 'refs/heads/%s' % branch
569 ref = 'refs/heads/%s' % branch
570 repo.refs[ref] = commit.id
570 repo.refs[ref] = commit.id
571
571
572 return commit.id
572 return commit.id
573
573
574 @reraise_safe_exceptions
574 @reraise_safe_exceptions
575 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
575 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
576 if url != 'default' and '://' not in url:
576 if url != 'default' and '://' not in url:
577 client = LocalGitClient(url)
577 client = LocalGitClient(url)
578 else:
578 else:
579 url_obj = url_parser(url)
579 url_obj = url_parser(url)
580 o = self._build_opener(url)
580 o = self._build_opener(url)
581 url, _ = url_obj.authinfo()
581 url, _ = url_obj.authinfo()
582 client = HttpGitClient(base_url=url, opener=o)
582 client = HttpGitClient(base_url=url, opener=o)
583 repo = self._factory.repo(wire)
583 repo = self._factory.repo(wire)
584
584
585 determine_wants = repo.object_store.determine_wants_all
585 determine_wants = repo.object_store.determine_wants_all
586 if refs:
586 if refs:
587 def determine_wants_requested(references):
587 def determine_wants_requested(references):
588 return [references[r] for r in references if r in refs]
588 return [references[r] for r in references if r in refs]
589 determine_wants = determine_wants_requested
589 determine_wants = determine_wants_requested
590
590
591 try:
591 try:
592 remote_refs = client.fetch(
592 remote_refs = client.fetch(
593 path=url, target=repo, determine_wants=determine_wants)
593 path=url, target=repo, determine_wants=determine_wants)
594 except NotGitRepository as e:
594 except NotGitRepository as e:
595 log.warning(
595 log.warning(
596 'Trying to fetch from "%s" failed, not a Git repository.', url)
596 'Trying to fetch from "%s" failed, not a Git repository.', url)
597 # Exception can contain unicode which we convert
597 # Exception can contain unicode which we convert
598 raise exceptions.AbortException(e)(repr(e))
598 raise exceptions.AbortException(e)(repr(e))
599
599
600 # mikhail: client.fetch() returns all the remote refs, but fetches only
600 # mikhail: client.fetch() returns all the remote refs, but fetches only
601 # refs filtered by `determine_wants` function. We need to filter result
601 # refs filtered by `determine_wants` function. We need to filter result
602 # as well
602 # as well
603 if refs:
603 if refs:
604 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
604 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
605
605
606 if apply_refs:
606 if apply_refs:
607 # TODO: johbo: Needs proper test coverage with a git repository
607 # TODO: johbo: Needs proper test coverage with a git repository
608 # that contains a tag object, so that we would end up with
608 # that contains a tag object, so that we would end up with
609 # a peeled ref at this point.
609 # a peeled ref at this point.
610 for k in remote_refs:
610 for k in remote_refs:
611 if k.endswith(PEELED_REF_MARKER):
611 if k.endswith(PEELED_REF_MARKER):
612 log.debug("Skipping peeled reference %s", k)
612 log.debug("Skipping peeled reference %s", k)
613 continue
613 continue
614 repo[k] = remote_refs[k]
614 repo[k] = remote_refs[k]
615
615
616 if refs and not update_after:
616 if refs and not update_after:
617 # mikhail: explicitly set the head to the last ref.
617 # mikhail: explicitly set the head to the last ref.
618 repo["HEAD"] = remote_refs[refs[-1]]
618 repo["HEAD"] = remote_refs[refs[-1]]
619
619
620 if update_after:
620 if update_after:
621 # we want to checkout HEAD
621 # we want to checkout HEAD
622 repo["HEAD"] = remote_refs["HEAD"]
622 repo["HEAD"] = remote_refs["HEAD"]
623 index.build_index_from_tree(repo.path, repo.index_path(),
623 index.build_index_from_tree(repo.path, repo.index_path(),
624 repo.object_store, repo["HEAD"].tree)
624 repo.object_store, repo["HEAD"].tree)
625 return remote_refs
625 return remote_refs
626
626
627 @reraise_safe_exceptions
627 @reraise_safe_exceptions
628 def sync_fetch(self, wire, url, refs=None, all_refs=False):
628 def sync_fetch(self, wire, url, refs=None, all_refs=False):
629 repo = self._factory.repo(wire)
629 repo = self._factory.repo(wire)
630 if refs and not isinstance(refs, (list, tuple)):
630 if refs and not isinstance(refs, (list, tuple)):
631 refs = [refs]
631 refs = [refs]
632
632
633 config = self._wire_to_config(wire)
633 config = self._wire_to_config(wire)
634 # get all remote refs we'll use to fetch later
634 # get all remote refs we'll use to fetch later
635 cmd = ['ls-remote']
635 cmd = ['ls-remote']
636 if not all_refs:
636 if not all_refs:
637 cmd += ['--heads', '--tags']
637 cmd += ['--heads', '--tags']
638 cmd += [url]
638 cmd += [url]
639 output, __ = self.run_git_command(
639 output, __ = self.run_git_command(
640 wire, cmd, fail_on_stderr=False,
640 wire, cmd, fail_on_stderr=False,
641 _copts=self._remote_conf(config),
641 _copts=self._remote_conf(config),
642 extra_env={'GIT_TERMINAL_PROMPT': '0'})
642 extra_env={'GIT_TERMINAL_PROMPT': '0'})
643
643
644 remote_refs = collections.OrderedDict()
644 remote_refs = collections.OrderedDict()
645 fetch_refs = []
645 fetch_refs = []
646
646
647 for ref_line in output.splitlines():
647 for ref_line in output.splitlines():
648 sha, ref = ref_line.split('\t')
648 sha, ref = ref_line.split('\t')
649 sha = sha.strip()
649 sha = sha.strip()
650 if ref in remote_refs:
650 if ref in remote_refs:
651 # duplicate, skip
651 # duplicate, skip
652 continue
652 continue
653 if ref.endswith(PEELED_REF_MARKER):
653 if ref.endswith(PEELED_REF_MARKER):
654 log.debug("Skipping peeled reference %s", ref)
654 log.debug("Skipping peeled reference %s", ref)
655 continue
655 continue
656 # don't sync HEAD
656 # don't sync HEAD
657 if ref in ['HEAD']:
657 if ref in ['HEAD']:
658 continue
658 continue
659
659
660 remote_refs[ref] = sha
660 remote_refs[ref] = sha
661
661
662 if refs and sha in refs:
662 if refs and sha in refs:
663 # we filter fetch using our specified refs
663 # we filter fetch using our specified refs
664 fetch_refs.append('{}:{}'.format(ref, ref))
664 fetch_refs.append('{}:{}'.format(ref, ref))
665 elif not refs:
665 elif not refs:
666 fetch_refs.append('{}:{}'.format(ref, ref))
666 fetch_refs.append('{}:{}'.format(ref, ref))
667 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
667 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
668
668
669 if fetch_refs:
669 if fetch_refs:
670 for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
670 for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
671 fetch_refs_chunks = list(chunk)
671 fetch_refs_chunks = list(chunk)
672 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
672 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
673 self.run_git_command(
673 self.run_git_command(
674 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
674 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
675 fail_on_stderr=False,
675 fail_on_stderr=False,
676 _copts=self._remote_conf(config),
676 _copts=self._remote_conf(config),
677 extra_env={'GIT_TERMINAL_PROMPT': '0'})
677 extra_env={'GIT_TERMINAL_PROMPT': '0'})
678
678
679 return remote_refs
679 return remote_refs
680
680
681 @reraise_safe_exceptions
681 @reraise_safe_exceptions
682 def sync_push(self, wire, url, refs=None):
682 def sync_push(self, wire, url, refs=None):
683 if not self.check_url(url, wire):
683 if not self.check_url(url, wire):
684 return
684 return
685 config = self._wire_to_config(wire)
685 config = self._wire_to_config(wire)
686 self._factory.repo(wire)
686 self._factory.repo(wire)
687 self.run_git_command(
687 self.run_git_command(
688 wire, ['push', url, '--mirror'], fail_on_stderr=False,
688 wire, ['push', url, '--mirror'], fail_on_stderr=False,
689 _copts=self._remote_conf(config),
689 _copts=self._remote_conf(config),
690 extra_env={'GIT_TERMINAL_PROMPT': '0'})
690 extra_env={'GIT_TERMINAL_PROMPT': '0'})
691
691
692 @reraise_safe_exceptions
692 @reraise_safe_exceptions
693 def get_remote_refs(self, wire, url):
693 def get_remote_refs(self, wire, url):
694 repo = Repo(url)
694 repo = Repo(url)
695 return repo.get_refs()
695 return repo.get_refs()
696
696
697 @reraise_safe_exceptions
697 @reraise_safe_exceptions
698 def get_description(self, wire):
698 def get_description(self, wire):
699 repo = self._factory.repo(wire)
699 repo = self._factory.repo(wire)
700 return repo.get_description()
700 return repo.get_description()
701
701
702 @reraise_safe_exceptions
702 @reraise_safe_exceptions
703 def get_missing_revs(self, wire, rev1, rev2, path2):
703 def get_missing_revs(self, wire, rev1, rev2, path2):
704 repo = self._factory.repo(wire)
704 repo = self._factory.repo(wire)
705 LocalGitClient(thin_packs=False).fetch(path2, repo)
705 LocalGitClient(thin_packs=False).fetch(path2, repo)
706
706
707 wire_remote = wire.copy()
707 wire_remote = wire.copy()
708 wire_remote['path'] = path2
708 wire_remote['path'] = path2
709 repo_remote = self._factory.repo(wire_remote)
709 repo_remote = self._factory.repo(wire_remote)
710 LocalGitClient(thin_packs=False).fetch(wire["path"], repo_remote)
710 LocalGitClient(thin_packs=False).fetch(wire["path"], repo_remote)
711
711
712 revs = [
712 revs = [
713 x.commit.id
713 x.commit.id
714 for x in repo_remote.get_walker(include=[rev2], exclude=[rev1])]
714 for x in repo_remote.get_walker(include=[rev2], exclude=[rev1])]
715 return revs
715 return revs
716
716
717 @reraise_safe_exceptions
717 @reraise_safe_exceptions
718 def get_object(self, wire, sha, maybe_unreachable=False):
718 def get_object(self, wire, sha, maybe_unreachable=False):
719 cache_on, context_uid, repo_id = self._cache_on(wire)
719 cache_on, context_uid, repo_id = self._cache_on(wire)
720 region = self._region(wire)
720 region = self._region(wire)
721
721
722 @region.conditional_cache_on_arguments(condition=cache_on)
722 @region.conditional_cache_on_arguments(condition=cache_on)
723 def _get_object(_context_uid, _repo_id, _sha):
723 def _get_object(_context_uid, _repo_id, _sha):
724 repo_init = self._factory.repo_libgit2(wire)
724 repo_init = self._factory.repo_libgit2(wire)
725 with repo_init as repo:
725 with repo_init as repo:
726
726
727 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
727 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
728 try:
728 try:
729 commit = repo.revparse_single(sha)
729 commit = repo.revparse_single(sha)
730 except KeyError:
730 except KeyError:
731 # NOTE(marcink): KeyError doesn't give us any meaningful information
731 # NOTE(marcink): KeyError doesn't give us any meaningful information
732 # here, we instead give something more explicit
732 # here, we instead give something more explicit
733 e = exceptions.RefNotFoundException('SHA: %s not found', sha)
733 e = exceptions.RefNotFoundException('SHA: %s not found', sha)
734 raise exceptions.LookupException(e)(missing_commit_err)
734 raise exceptions.LookupException(e)(missing_commit_err)
735 except ValueError as e:
735 except ValueError as e:
736 raise exceptions.LookupException(e)(missing_commit_err)
736 raise exceptions.LookupException(e)(missing_commit_err)
737
737
738 is_tag = False
738 is_tag = False
739 if isinstance(commit, pygit2.Tag):
739 if isinstance(commit, pygit2.Tag):
740 commit = repo.get(commit.target)
740 commit = repo.get(commit.target)
741 is_tag = True
741 is_tag = True
742
742
743 check_dangling = True
743 check_dangling = True
744 if is_tag:
744 if is_tag:
745 check_dangling = False
745 check_dangling = False
746
746
747 if check_dangling and maybe_unreachable:
747 if check_dangling and maybe_unreachable:
748 check_dangling = False
748 check_dangling = False
749
749
750 # we used a reference and it parsed means we're not having a dangling commit
750 # we used a reference and it parsed means we're not having a dangling commit
751 if sha != commit.hex:
751 if sha != commit.hex:
752 check_dangling = False
752 check_dangling = False
753
753
754 if check_dangling:
754 if check_dangling:
755 # check for dangling commit
755 # check for dangling commit
756 for branch in repo.branches.with_commit(commit.hex):
756 for branch in repo.branches.with_commit(commit.hex):
757 if branch:
757 if branch:
758 break
758 break
759 else:
759 else:
760 # NOTE(marcink): Empty error doesn't give us any meaningful information
760 # NOTE(marcink): Empty error doesn't give us any meaningful information
761 # here, we instead give something more explicit
761 # here, we instead give something more explicit
762 e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
762 e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
763 raise exceptions.LookupException(e)(missing_commit_err)
763 raise exceptions.LookupException(e)(missing_commit_err)
764
764
765 commit_id = commit.hex
765 commit_id = commit.hex
766 type_id = commit.type
766 type_id = commit.type
767
767
768 return {
768 return {
769 'id': commit_id,
769 'id': commit_id,
770 'type': self._type_id_to_name(type_id),
770 'type': self._type_id_to_name(type_id),
771 'commit_id': commit_id,
771 'commit_id': commit_id,
772 'idx': 0
772 'idx': 0
773 }
773 }
774
774
775 return _get_object(context_uid, repo_id, sha)
775 return _get_object(context_uid, repo_id, sha)
776
776
777 @reraise_safe_exceptions
777 @reraise_safe_exceptions
778 def get_refs(self, wire):
778 def get_refs(self, wire):
779 cache_on, context_uid, repo_id = self._cache_on(wire)
779 cache_on, context_uid, repo_id = self._cache_on(wire)
780 region = self._region(wire)
780 region = self._region(wire)
781
781
782 @region.conditional_cache_on_arguments(condition=cache_on)
782 @region.conditional_cache_on_arguments(condition=cache_on)
783 def _get_refs(_context_uid, _repo_id):
783 def _get_refs(_context_uid, _repo_id):
784
784
785 repo_init = self._factory.repo_libgit2(wire)
785 repo_init = self._factory.repo_libgit2(wire)
786 with repo_init as repo:
786 with repo_init as repo:
787 regex = re.compile('^refs/(heads|tags)/')
787 regex = re.compile('^refs/(heads|tags)/')
788 return {x.name: x.target.hex for x in
788 return {x.name: x.target.hex for x in
789 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
789 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
790
790
791 return _get_refs(context_uid, repo_id)
791 return _get_refs(context_uid, repo_id)
792
792
793 @reraise_safe_exceptions
793 @reraise_safe_exceptions
794 def get_branch_pointers(self, wire):
794 def get_branch_pointers(self, wire):
795 cache_on, context_uid, repo_id = self._cache_on(wire)
795 cache_on, context_uid, repo_id = self._cache_on(wire)
796 region = self._region(wire)
796 region = self._region(wire)
797
797
798 @region.conditional_cache_on_arguments(condition=cache_on)
798 @region.conditional_cache_on_arguments(condition=cache_on)
799 def _get_branch_pointers(_context_uid, _repo_id):
799 def _get_branch_pointers(_context_uid, _repo_id):
800
800
801 repo_init = self._factory.repo_libgit2(wire)
801 repo_init = self._factory.repo_libgit2(wire)
802 regex = re.compile('^refs/heads')
802 regex = re.compile('^refs/heads')
803 with repo_init as repo:
803 with repo_init as repo:
804 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
804 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
805 return {x.target.hex: x.shorthand for x in branches}
805 return {x.target.hex: x.shorthand for x in branches}
806
806
807 return _get_branch_pointers(context_uid, repo_id)
807 return _get_branch_pointers(context_uid, repo_id)
808
808
809 @reraise_safe_exceptions
809 @reraise_safe_exceptions
810 def head(self, wire, show_exc=True):
810 def head(self, wire, show_exc=True):
811 cache_on, context_uid, repo_id = self._cache_on(wire)
811 cache_on, context_uid, repo_id = self._cache_on(wire)
812 region = self._region(wire)
812 region = self._region(wire)
813
813
814 @region.conditional_cache_on_arguments(condition=cache_on)
814 @region.conditional_cache_on_arguments(condition=cache_on)
815 def _head(_context_uid, _repo_id, _show_exc):
815 def _head(_context_uid, _repo_id, _show_exc):
816 repo_init = self._factory.repo_libgit2(wire)
816 repo_init = self._factory.repo_libgit2(wire)
817 with repo_init as repo:
817 with repo_init as repo:
818 try:
818 try:
819 return repo.head.peel().hex
819 return repo.head.peel().hex
820 except Exception:
820 except Exception:
821 if show_exc:
821 if show_exc:
822 raise
822 raise
823 return _head(context_uid, repo_id, show_exc)
823 return _head(context_uid, repo_id, show_exc)
824
824
825 @reraise_safe_exceptions
825 @reraise_safe_exceptions
826 def init(self, wire):
826 def init(self, wire):
827 repo_path = str_to_dulwich(wire['path'])
827 repo_path = safe_str(wire['path'])
828 self.repo = Repo.init(repo_path)
828 self.repo = Repo.init(repo_path)
829
829
830 @reraise_safe_exceptions
830 @reraise_safe_exceptions
831 def init_bare(self, wire):
831 def init_bare(self, wire):
832 repo_path = str_to_dulwich(wire['path'])
832 repo_path = safe_str(wire['path'])
833 self.repo = Repo.init_bare(repo_path)
833 self.repo = Repo.init_bare(repo_path)
834
834
835 @reraise_safe_exceptions
835 @reraise_safe_exceptions
836 def revision(self, wire, rev):
836 def revision(self, wire, rev):
837
837
838 cache_on, context_uid, repo_id = self._cache_on(wire)
838 cache_on, context_uid, repo_id = self._cache_on(wire)
839 region = self._region(wire)
839 region = self._region(wire)
840
840
841 @region.conditional_cache_on_arguments(condition=cache_on)
841 @region.conditional_cache_on_arguments(condition=cache_on)
842 def _revision(_context_uid, _repo_id, _rev):
842 def _revision(_context_uid, _repo_id, _rev):
843 repo_init = self._factory.repo_libgit2(wire)
843 repo_init = self._factory.repo_libgit2(wire)
844 with repo_init as repo:
844 with repo_init as repo:
845 commit = repo[rev]
845 commit = repo[rev]
846 obj_data = {
846 obj_data = {
847 'id': commit.id.hex,
847 'id': commit.id.hex,
848 }
848 }
849 # tree objects itself don't have tree_id attribute
849 # tree objects itself don't have tree_id attribute
850 if hasattr(commit, 'tree_id'):
850 if hasattr(commit, 'tree_id'):
851 obj_data['tree'] = commit.tree_id.hex
851 obj_data['tree'] = commit.tree_id.hex
852
852
853 return obj_data
853 return obj_data
854 return _revision(context_uid, repo_id, rev)
854 return _revision(context_uid, repo_id, rev)
855
855
856 @reraise_safe_exceptions
856 @reraise_safe_exceptions
857 def date(self, wire, commit_id):
857 def date(self, wire, commit_id):
858 cache_on, context_uid, repo_id = self._cache_on(wire)
858 cache_on, context_uid, repo_id = self._cache_on(wire)
859 region = self._region(wire)
859 region = self._region(wire)
860
860
861 @region.conditional_cache_on_arguments(condition=cache_on)
861 @region.conditional_cache_on_arguments(condition=cache_on)
862 def _date(_repo_id, _commit_id):
862 def _date(_repo_id, _commit_id):
863 repo_init = self._factory.repo_libgit2(wire)
863 repo_init = self._factory.repo_libgit2(wire)
864 with repo_init as repo:
864 with repo_init as repo:
865 commit = repo[commit_id]
865 commit = repo[commit_id]
866
866
867 if hasattr(commit, 'commit_time'):
867 if hasattr(commit, 'commit_time'):
868 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
868 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
869 else:
869 else:
870 commit = commit.get_object()
870 commit = commit.get_object()
871 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
871 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
872
872
873 # TODO(marcink): check dulwich difference of offset vs timezone
873 # TODO(marcink): check dulwich difference of offset vs timezone
874 return [commit_time, commit_time_offset]
874 return [commit_time, commit_time_offset]
875 return _date(repo_id, commit_id)
875 return _date(repo_id, commit_id)
876
876
877 @reraise_safe_exceptions
877 @reraise_safe_exceptions
878 def author(self, wire, commit_id):
878 def author(self, wire, commit_id):
879 cache_on, context_uid, repo_id = self._cache_on(wire)
879 cache_on, context_uid, repo_id = self._cache_on(wire)
880 region = self._region(wire)
880 region = self._region(wire)
881
881
882 @region.conditional_cache_on_arguments(condition=cache_on)
882 @region.conditional_cache_on_arguments(condition=cache_on)
883 def _author(_repo_id, _commit_id):
883 def _author(_repo_id, _commit_id):
884 repo_init = self._factory.repo_libgit2(wire)
884 repo_init = self._factory.repo_libgit2(wire)
885 with repo_init as repo:
885 with repo_init as repo:
886 commit = repo[commit_id]
886 commit = repo[commit_id]
887
887
888 if hasattr(commit, 'author'):
888 if hasattr(commit, 'author'):
889 author = commit.author
889 author = commit.author
890 else:
890 else:
891 author = commit.get_object().author
891 author = commit.get_object().author
892
892
893 if author.email:
893 if author.email:
894 return "{} <{}>".format(author.name, author.email)
894 return "{} <{}>".format(author.name, author.email)
895
895
896 try:
896 try:
897 return "{}".format(author.name)
897 return "{}".format(author.name)
898 except Exception:
898 except Exception:
899 return "{}".format(safe_str(author.raw_name))
899 return "{}".format(safe_str(author.raw_name))
900
900
901 return _author(repo_id, commit_id)
901 return _author(repo_id, commit_id)
902
902
903 @reraise_safe_exceptions
903 @reraise_safe_exceptions
904 def message(self, wire, commit_id):
904 def message(self, wire, commit_id):
905 cache_on, context_uid, repo_id = self._cache_on(wire)
905 cache_on, context_uid, repo_id = self._cache_on(wire)
906 region = self._region(wire)
906 region = self._region(wire)
907 @region.conditional_cache_on_arguments(condition=cache_on)
907 @region.conditional_cache_on_arguments(condition=cache_on)
908 def _message(_repo_id, _commit_id):
908 def _message(_repo_id, _commit_id):
909 repo_init = self._factory.repo_libgit2(wire)
909 repo_init = self._factory.repo_libgit2(wire)
910 with repo_init as repo:
910 with repo_init as repo:
911 commit = repo[commit_id]
911 commit = repo[commit_id]
912 return commit.message
912 return commit.message
913 return _message(repo_id, commit_id)
913 return _message(repo_id, commit_id)
914
914
915 @reraise_safe_exceptions
915 @reraise_safe_exceptions
916 def parents(self, wire, commit_id):
916 def parents(self, wire, commit_id):
917 cache_on, context_uid, repo_id = self._cache_on(wire)
917 cache_on, context_uid, repo_id = self._cache_on(wire)
918 region = self._region(wire)
918 region = self._region(wire)
919 @region.conditional_cache_on_arguments(condition=cache_on)
919 @region.conditional_cache_on_arguments(condition=cache_on)
920 def _parents(_repo_id, _commit_id):
920 def _parents(_repo_id, _commit_id):
921 repo_init = self._factory.repo_libgit2(wire)
921 repo_init = self._factory.repo_libgit2(wire)
922 with repo_init as repo:
922 with repo_init as repo:
923 commit = repo[commit_id]
923 commit = repo[commit_id]
924 if hasattr(commit, 'parent_ids'):
924 if hasattr(commit, 'parent_ids'):
925 parent_ids = commit.parent_ids
925 parent_ids = commit.parent_ids
926 else:
926 else:
927 parent_ids = commit.get_object().parent_ids
927 parent_ids = commit.get_object().parent_ids
928
928
929 return [x.hex for x in parent_ids]
929 return [x.hex for x in parent_ids]
930 return _parents(repo_id, commit_id)
930 return _parents(repo_id, commit_id)
931
931
932 @reraise_safe_exceptions
932 @reraise_safe_exceptions
933 def children(self, wire, commit_id):
933 def children(self, wire, commit_id):
934 cache_on, context_uid, repo_id = self._cache_on(wire)
934 cache_on, context_uid, repo_id = self._cache_on(wire)
935 region = self._region(wire)
935 region = self._region(wire)
936
936
937 @region.conditional_cache_on_arguments(condition=cache_on)
937 @region.conditional_cache_on_arguments(condition=cache_on)
938 def _children(_repo_id, _commit_id):
938 def _children(_repo_id, _commit_id):
939 output, __ = self.run_git_command(
939 output, __ = self.run_git_command(
940 wire, ['rev-list', '--all', '--children'])
940 wire, ['rev-list', '--all', '--children'])
941
941
942 child_ids = []
942 child_ids = []
943 pat = re.compile(r'^%s' % commit_id)
943 pat = re.compile(r'^%s' % commit_id)
944 for l in output.splitlines():
944 for l in output.splitlines():
945 if pat.match(l):
945 if pat.match(l):
946 found_ids = l.split(' ')[1:]
946 found_ids = l.split(' ')[1:]
947 child_ids.extend(found_ids)
947 child_ids.extend(found_ids)
948
948
949 return child_ids
949 return child_ids
950 return _children(repo_id, commit_id)
950 return _children(repo_id, commit_id)
951
951
952 @reraise_safe_exceptions
952 @reraise_safe_exceptions
953 def set_refs(self, wire, key, value):
953 def set_refs(self, wire, key, value):
954 repo_init = self._factory.repo_libgit2(wire)
954 repo_init = self._factory.repo_libgit2(wire)
955 with repo_init as repo:
955 with repo_init as repo:
956 repo.references.create(key, value, force=True)
956 repo.references.create(key, value, force=True)
957
957
958 @reraise_safe_exceptions
958 @reraise_safe_exceptions
959 def create_branch(self, wire, branch_name, commit_id, force=False):
959 def create_branch(self, wire, branch_name, commit_id, force=False):
960 repo_init = self._factory.repo_libgit2(wire)
960 repo_init = self._factory.repo_libgit2(wire)
961 with repo_init as repo:
961 with repo_init as repo:
962 commit = repo[commit_id]
962 commit = repo[commit_id]
963
963
964 if force:
964 if force:
965 repo.branches.local.create(branch_name, commit, force=force)
965 repo.branches.local.create(branch_name, commit, force=force)
966 elif not repo.branches.get(branch_name):
966 elif not repo.branches.get(branch_name):
967 # create only if that branch isn't existing
967 # create only if that branch isn't existing
968 repo.branches.local.create(branch_name, commit, force=force)
968 repo.branches.local.create(branch_name, commit, force=force)
969
969
970 @reraise_safe_exceptions
970 @reraise_safe_exceptions
971 def remove_ref(self, wire, key):
971 def remove_ref(self, wire, key):
972 repo_init = self._factory.repo_libgit2(wire)
972 repo_init = self._factory.repo_libgit2(wire)
973 with repo_init as repo:
973 with repo_init as repo:
974 repo.references.delete(key)
974 repo.references.delete(key)
975
975
976 @reraise_safe_exceptions
976 @reraise_safe_exceptions
977 def tag_remove(self, wire, tag_name):
977 def tag_remove(self, wire, tag_name):
978 repo_init = self._factory.repo_libgit2(wire)
978 repo_init = self._factory.repo_libgit2(wire)
979 with repo_init as repo:
979 with repo_init as repo:
980 key = 'refs/tags/{}'.format(tag_name)
980 key = 'refs/tags/{}'.format(tag_name)
981 repo.references.delete(key)
981 repo.references.delete(key)
982
982
983 @reraise_safe_exceptions
983 @reraise_safe_exceptions
984 def tree_changes(self, wire, source_id, target_id):
984 def tree_changes(self, wire, source_id, target_id):
985 # TODO(marcink): remove this seems it's only used by tests
985 # TODO(marcink): remove this seems it's only used by tests
986 repo = self._factory.repo(wire)
986 repo = self._factory.repo(wire)
987 source = repo[source_id].tree if source_id else None
987 source = repo[source_id].tree if source_id else None
988 target = repo[target_id].tree
988 target = repo[target_id].tree
989 result = repo.object_store.tree_changes(source, target)
989 result = repo.object_store.tree_changes(source, target)
990 return list(result)
990 return list(result)
991
991
992 @reraise_safe_exceptions
992 @reraise_safe_exceptions
993 def tree_and_type_for_path(self, wire, commit_id, path):
993 def tree_and_type_for_path(self, wire, commit_id, path):
994
994
995 cache_on, context_uid, repo_id = self._cache_on(wire)
995 cache_on, context_uid, repo_id = self._cache_on(wire)
996 region = self._region(wire)
996 region = self._region(wire)
997
997
998 @region.conditional_cache_on_arguments(condition=cache_on)
998 @region.conditional_cache_on_arguments(condition=cache_on)
999 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
999 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
1000 repo_init = self._factory.repo_libgit2(wire)
1000 repo_init = self._factory.repo_libgit2(wire)
1001
1001
1002 with repo_init as repo:
1002 with repo_init as repo:
1003 commit = repo[commit_id]
1003 commit = repo[commit_id]
1004 try:
1004 try:
1005 tree = commit.tree[path]
1005 tree = commit.tree[path]
1006 except KeyError:
1006 except KeyError:
1007 return None, None, None
1007 return None, None, None
1008
1008
1009 return tree.id.hex, tree.type_str, tree.filemode
1009 return tree.id.hex, tree.type_str, tree.filemode
1010 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1010 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1011
1011
1012 @reraise_safe_exceptions
1012 @reraise_safe_exceptions
1013 def tree_items(self, wire, tree_id):
1013 def tree_items(self, wire, tree_id):
1014 cache_on, context_uid, repo_id = self._cache_on(wire)
1014 cache_on, context_uid, repo_id = self._cache_on(wire)
1015 region = self._region(wire)
1015 region = self._region(wire)
1016
1016
1017 @region.conditional_cache_on_arguments(condition=cache_on)
1017 @region.conditional_cache_on_arguments(condition=cache_on)
1018 def _tree_items(_repo_id, _tree_id):
1018 def _tree_items(_repo_id, _tree_id):
1019
1019
1020 repo_init = self._factory.repo_libgit2(wire)
1020 repo_init = self._factory.repo_libgit2(wire)
1021 with repo_init as repo:
1021 with repo_init as repo:
1022 try:
1022 try:
1023 tree = repo[tree_id]
1023 tree = repo[tree_id]
1024 except KeyError:
1024 except KeyError:
1025 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1025 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1026
1026
1027 result = []
1027 result = []
1028 for item in tree:
1028 for item in tree:
1029 item_sha = item.hex
1029 item_sha = item.hex
1030 item_mode = item.filemode
1030 item_mode = item.filemode
1031 item_type = item.type_str
1031 item_type = item.type_str
1032
1032
1033 if item_type == 'commit':
1033 if item_type == 'commit':
1034 # NOTE(marcink): submodules we translate to 'link' for backward compat
1034 # NOTE(marcink): submodules we translate to 'link' for backward compat
1035 item_type = 'link'
1035 item_type = 'link'
1036
1036
1037 result.append((item.name, item_mode, item_sha, item_type))
1037 result.append((item.name, item_mode, item_sha, item_type))
1038 return result
1038 return result
1039 return _tree_items(repo_id, tree_id)
1039 return _tree_items(repo_id, tree_id)
1040
1040
1041 @reraise_safe_exceptions
1041 @reraise_safe_exceptions
1042 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1042 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1043 """
1043 """
1044 Old version that uses subprocess to call diff
1044 Old version that uses subprocess to call diff
1045 """
1045 """
1046
1046
1047 flags = [
1047 flags = [
1048 '-U%s' % context, '--patch',
1048 '-U%s' % context, '--patch',
1049 '--binary',
1049 '--binary',
1050 '--find-renames',
1050 '--find-renames',
1051 '--no-indent-heuristic',
1051 '--no-indent-heuristic',
1052 # '--indent-heuristic',
1052 # '--indent-heuristic',
1053 #'--full-index',
1053 #'--full-index',
1054 #'--abbrev=40'
1054 #'--abbrev=40'
1055 ]
1055 ]
1056
1056
1057 if opt_ignorews:
1057 if opt_ignorews:
1058 flags.append('--ignore-all-space')
1058 flags.append('--ignore-all-space')
1059
1059
1060 if commit_id_1 == self.EMPTY_COMMIT:
1060 if commit_id_1 == self.EMPTY_COMMIT:
1061 cmd = ['show'] + flags + [commit_id_2]
1061 cmd = ['show'] + flags + [commit_id_2]
1062 else:
1062 else:
1063 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1063 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1064
1064
1065 if file_filter:
1065 if file_filter:
1066 cmd.extend(['--', file_filter])
1066 cmd.extend(['--', file_filter])
1067
1067
1068 diff, __ = self.run_git_command(wire, cmd)
1068 diff, __ = self.run_git_command(wire, cmd)
1069 # If we used 'show' command, strip first few lines (until actual diff
1069 # If we used 'show' command, strip first few lines (until actual diff
1070 # starts)
1070 # starts)
1071 if commit_id_1 == self.EMPTY_COMMIT:
1071 if commit_id_1 == self.EMPTY_COMMIT:
1072 lines = diff.splitlines()
1072 lines = diff.splitlines()
1073 x = 0
1073 x = 0
1074 for line in lines:
1074 for line in lines:
1075 if line.startswith(b'diff'):
1075 if line.startswith(b'diff'):
1076 break
1076 break
1077 x += 1
1077 x += 1
1078 # Append new line just like 'diff' command do
1078 # Append new line just like 'diff' command do
1079 diff = '\n'.join(lines[x:]) + '\n'
1079 diff = '\n'.join(lines[x:]) + '\n'
1080 return diff
1080 return diff
1081
1081
1082 @reraise_safe_exceptions
1082 @reraise_safe_exceptions
1083 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1083 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1084 repo_init = self._factory.repo_libgit2(wire)
1084 repo_init = self._factory.repo_libgit2(wire)
1085 with repo_init as repo:
1085 with repo_init as repo:
1086 swap = True
1086 swap = True
1087 flags = 0
1087 flags = 0
1088 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1088 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1089
1089
1090 if opt_ignorews:
1090 if opt_ignorews:
1091 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1091 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1092
1092
1093 if commit_id_1 == self.EMPTY_COMMIT:
1093 if commit_id_1 == self.EMPTY_COMMIT:
1094 comm1 = repo[commit_id_2]
1094 comm1 = repo[commit_id_2]
1095 diff_obj = comm1.tree.diff_to_tree(
1095 diff_obj = comm1.tree.diff_to_tree(
1096 flags=flags, context_lines=context, swap=swap)
1096 flags=flags, context_lines=context, swap=swap)
1097
1097
1098 else:
1098 else:
1099 comm1 = repo[commit_id_2]
1099 comm1 = repo[commit_id_2]
1100 comm2 = repo[commit_id_1]
1100 comm2 = repo[commit_id_1]
1101 diff_obj = comm1.tree.diff_to_tree(
1101 diff_obj = comm1.tree.diff_to_tree(
1102 comm2.tree, flags=flags, context_lines=context, swap=swap)
1102 comm2.tree, flags=flags, context_lines=context, swap=swap)
1103 similar_flags = 0
1103 similar_flags = 0
1104 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1104 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1105 diff_obj.find_similar(flags=similar_flags)
1105 diff_obj.find_similar(flags=similar_flags)
1106
1106
1107 if file_filter:
1107 if file_filter:
1108 for p in diff_obj:
1108 for p in diff_obj:
1109 if p.delta.old_file.path == file_filter:
1109 if p.delta.old_file.path == file_filter:
1110 return p.patch or ''
1110 return p.patch or ''
1111 # fo matching path == no diff
1111 # fo matching path == no diff
1112 return ''
1112 return ''
1113 return diff_obj.patch or ''
1113 return diff_obj.patch or ''
1114
1114
1115 @reraise_safe_exceptions
1115 @reraise_safe_exceptions
1116 def node_history(self, wire, commit_id, path, limit):
1116 def node_history(self, wire, commit_id, path, limit):
1117 cache_on, context_uid, repo_id = self._cache_on(wire)
1117 cache_on, context_uid, repo_id = self._cache_on(wire)
1118 region = self._region(wire)
1118 region = self._region(wire)
1119
1119
1120 @region.conditional_cache_on_arguments(condition=cache_on)
1120 @region.conditional_cache_on_arguments(condition=cache_on)
1121 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1121 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1122 # optimize for n==1, rev-list is much faster for that use-case
1122 # optimize for n==1, rev-list is much faster for that use-case
1123 if limit == 1:
1123 if limit == 1:
1124 cmd = ['rev-list', '-1', commit_id, '--', path]
1124 cmd = ['rev-list', '-1', commit_id, '--', path]
1125 else:
1125 else:
1126 cmd = ['log']
1126 cmd = ['log']
1127 if limit:
1127 if limit:
1128 cmd.extend(['-n', str(safe_int(limit, 0))])
1128 cmd.extend(['-n', str(safe_int(limit, 0))])
1129 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1129 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1130
1130
1131 output, __ = self.run_git_command(wire, cmd)
1131 output, __ = self.run_git_command(wire, cmd)
1132 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1132 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1133
1133
1134 return [x for x in commit_ids]
1134 return [x for x in commit_ids]
1135 return _node_history(context_uid, repo_id, commit_id, path, limit)
1135 return _node_history(context_uid, repo_id, commit_id, path, limit)
1136
1136
1137 @reraise_safe_exceptions
1137 @reraise_safe_exceptions
1138 def node_annotate_legacy(self, wire, commit_id, path):
1138 def node_annotate_legacy(self, wire, commit_id, path):
1139 #note: replaced by pygit2 impelementation
1139 #note: replaced by pygit2 impelementation
1140 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1140 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1141 # -l ==> outputs long shas (and we need all 40 characters)
1141 # -l ==> outputs long shas (and we need all 40 characters)
1142 # --root ==> doesn't put '^' character for boundaries
1142 # --root ==> doesn't put '^' character for boundaries
1143 # -r commit_id ==> blames for the given commit
1143 # -r commit_id ==> blames for the given commit
1144 output, __ = self.run_git_command(wire, cmd)
1144 output, __ = self.run_git_command(wire, cmd)
1145
1145
1146 result = []
1146 result = []
1147 for i, blame_line in enumerate(output.splitlines()[:-1]):
1147 for i, blame_line in enumerate(output.splitlines()[:-1]):
1148 line_no = i + 1
1148 line_no = i + 1
1149 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1149 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1150 result.append((line_no, blame_commit_id, line))
1150 result.append((line_no, blame_commit_id, line))
1151
1151
1152 return result
1152 return result
1153
1153
1154 @reraise_safe_exceptions
1154 @reraise_safe_exceptions
1155 def node_annotate(self, wire, commit_id, path):
1155 def node_annotate(self, wire, commit_id, path):
1156
1156
1157 result_libgit = []
1157 result_libgit = []
1158 repo_init = self._factory.repo_libgit2(wire)
1158 repo_init = self._factory.repo_libgit2(wire)
1159 with repo_init as repo:
1159 with repo_init as repo:
1160 commit = repo[commit_id]
1160 commit = repo[commit_id]
1161 blame_obj = repo.blame(path, newest_commit=commit_id)
1161 blame_obj = repo.blame(path, newest_commit=commit_id)
1162 for i, line in enumerate(commit.tree[path].data.splitlines()):
1162 for i, line in enumerate(commit.tree[path].data.splitlines()):
1163 line_no = i + 1
1163 line_no = i + 1
1164 hunk = blame_obj.for_line(line_no)
1164 hunk = blame_obj.for_line(line_no)
1165 blame_commit_id = hunk.final_commit_id.hex
1165 blame_commit_id = hunk.final_commit_id.hex
1166
1166
1167 result_libgit.append((line_no, blame_commit_id, line))
1167 result_libgit.append((line_no, blame_commit_id, line))
1168
1168
1169 return result_libgit
1169 return result_libgit
1170
1170
1171 @reraise_safe_exceptions
1171 @reraise_safe_exceptions
1172 def update_server_info(self, wire):
1172 def update_server_info(self, wire):
1173 repo = self._factory.repo(wire)
1173 repo = self._factory.repo(wire)
1174 update_server_info(repo)
1174 update_server_info(repo)
1175
1175
1176 @reraise_safe_exceptions
1176 @reraise_safe_exceptions
1177 def get_all_commit_ids(self, wire):
1177 def get_all_commit_ids(self, wire):
1178
1178
1179 cache_on, context_uid, repo_id = self._cache_on(wire)
1179 cache_on, context_uid, repo_id = self._cache_on(wire)
1180 region = self._region(wire)
1180 region = self._region(wire)
1181
1181
1182 @region.conditional_cache_on_arguments(condition=cache_on)
1182 @region.conditional_cache_on_arguments(condition=cache_on)
1183 def _get_all_commit_ids(_context_uid, _repo_id):
1183 def _get_all_commit_ids(_context_uid, _repo_id):
1184
1184
1185 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1185 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1186 try:
1186 try:
1187 output, __ = self.run_git_command(wire, cmd)
1187 output, __ = self.run_git_command(wire, cmd)
1188 return output.splitlines()
1188 return output.splitlines()
1189 except Exception:
1189 except Exception:
1190 # Can be raised for empty repositories
1190 # Can be raised for empty repositories
1191 return []
1191 return []
1192
1192
1193 @region.conditional_cache_on_arguments(condition=cache_on)
1193 @region.conditional_cache_on_arguments(condition=cache_on)
1194 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1194 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1195 repo_init = self._factory.repo_libgit2(wire)
1195 repo_init = self._factory.repo_libgit2(wire)
1196 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1196 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1197 results = []
1197 results = []
1198 with repo_init as repo:
1198 with repo_init as repo:
1199 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1199 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1200 results.append(commit.id.hex)
1200 results.append(commit.id.hex)
1201
1201
1202 return _get_all_commit_ids(context_uid, repo_id)
1202 return _get_all_commit_ids(context_uid, repo_id)
1203
1203
1204 @reraise_safe_exceptions
1204 @reraise_safe_exceptions
1205 def run_git_command(self, wire, cmd, **opts):
1205 def run_git_command(self, wire, cmd, **opts):
1206 path = wire.get('path', None)
1206 path = wire.get('path', None)
1207
1207
1208 if path and os.path.isdir(path):
1208 if path and os.path.isdir(path):
1209 opts['cwd'] = path
1209 opts['cwd'] = path
1210
1210
1211 if '_bare' in opts:
1211 if '_bare' in opts:
1212 _copts = []
1212 _copts = []
1213 del opts['_bare']
1213 del opts['_bare']
1214 else:
1214 else:
1215 _copts = ['-c', 'core.quotepath=false', ]
1215 _copts = ['-c', 'core.quotepath=false', ]
1216 safe_call = False
1216 safe_call = False
1217 if '_safe' in opts:
1217 if '_safe' in opts:
1218 # no exc on failure
1218 # no exc on failure
1219 del opts['_safe']
1219 del opts['_safe']
1220 safe_call = True
1220 safe_call = True
1221
1221
1222 if '_copts' in opts:
1222 if '_copts' in opts:
1223 _copts.extend(opts['_copts'] or [])
1223 _copts.extend(opts['_copts'] or [])
1224 del opts['_copts']
1224 del opts['_copts']
1225
1225
1226 gitenv = os.environ.copy()
1226 gitenv = os.environ.copy()
1227 gitenv.update(opts.pop('extra_env', {}))
1227 gitenv.update(opts.pop('extra_env', {}))
1228 # need to clean fix GIT_DIR !
1228 # need to clean fix GIT_DIR !
1229 if 'GIT_DIR' in gitenv:
1229 if 'GIT_DIR' in gitenv:
1230 del gitenv['GIT_DIR']
1230 del gitenv['GIT_DIR']
1231 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1231 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1232 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1232 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1233
1233
1234 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1234 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1235 _opts = {'env': gitenv, 'shell': False}
1235 _opts = {'env': gitenv, 'shell': False}
1236
1236
1237 proc = None
1237 proc = None
1238 try:
1238 try:
1239 _opts.update(opts)
1239 _opts.update(opts)
1240 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1240 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1241
1241
1242 return b''.join(proc), b''.join(proc.stderr)
1242 return b''.join(proc), b''.join(proc.stderr)
1243 except OSError as err:
1243 except OSError as err:
1244 cmd = ' '.join(cmd) # human friendly CMD
1244 cmd = ' '.join(cmd) # human friendly CMD
1245 tb_err = ("Couldn't run git command (%s).\n"
1245 tb_err = ("Couldn't run git command (%s).\n"
1246 "Original error was:%s\n"
1246 "Original error was:%s\n"
1247 "Call options:%s\n"
1247 "Call options:%s\n"
1248 % (cmd, err, _opts))
1248 % (cmd, err, _opts))
1249 log.exception(tb_err)
1249 log.exception(tb_err)
1250 if safe_call:
1250 if safe_call:
1251 return '', err
1251 return '', err
1252 else:
1252 else:
1253 raise exceptions.VcsException()(tb_err)
1253 raise exceptions.VcsException()(tb_err)
1254 finally:
1254 finally:
1255 if proc:
1255 if proc:
1256 proc.close()
1256 proc.close()
1257
1257
1258 @reraise_safe_exceptions
1258 @reraise_safe_exceptions
1259 def install_hooks(self, wire, force=False):
1259 def install_hooks(self, wire, force=False):
1260 from vcsserver.hook_utils import install_git_hooks
1260 from vcsserver.hook_utils import install_git_hooks
1261 bare = self.bare(wire)
1261 bare = self.bare(wire)
1262 path = wire['path']
1262 path = wire['path']
1263 return install_git_hooks(path, bare, force_create=force)
1263 return install_git_hooks(path, bare, force_create=force)
1264
1264
1265 @reraise_safe_exceptions
1265 @reraise_safe_exceptions
1266 def get_hooks_info(self, wire):
1266 def get_hooks_info(self, wire):
1267 from vcsserver.hook_utils import (
1267 from vcsserver.hook_utils import (
1268 get_git_pre_hook_version, get_git_post_hook_version)
1268 get_git_pre_hook_version, get_git_post_hook_version)
1269 bare = self.bare(wire)
1269 bare = self.bare(wire)
1270 path = wire['path']
1270 path = wire['path']
1271 return {
1271 return {
1272 'pre_version': get_git_pre_hook_version(path, bare),
1272 'pre_version': get_git_pre_hook_version(path, bare),
1273 'post_version': get_git_post_hook_version(path, bare),
1273 'post_version': get_git_post_hook_version(path, bare),
1274 }
1274 }
1275
1275
1276 @reraise_safe_exceptions
1276 @reraise_safe_exceptions
1277 def set_head_ref(self, wire, head_name):
1277 def set_head_ref(self, wire, head_name):
1278 log.debug('Setting refs/head to `%s`', head_name)
1278 log.debug('Setting refs/head to `%s`', head_name)
1279 cmd = ['symbolic-ref', '"HEAD"', '"refs/heads/%s"' % head_name]
1279 cmd = ['symbolic-ref', '"HEAD"', '"refs/heads/%s"' % head_name]
1280 output, __ = self.run_git_command(wire, cmd)
1280 output, __ = self.run_git_command(wire, cmd)
1281 return [head_name] + output.splitlines()
1281 return [head_name] + output.splitlines()
1282
1282
1283 @reraise_safe_exceptions
1283 @reraise_safe_exceptions
1284 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
1284 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
1285 archive_dir_name, commit_id):
1285 archive_dir_name, commit_id):
1286
1286
1287 def file_walker(_commit_id, path):
1287 def file_walker(_commit_id, path):
1288 repo_init = self._factory.repo_libgit2(wire)
1288 repo_init = self._factory.repo_libgit2(wire)
1289
1289
1290 with repo_init as repo:
1290 with repo_init as repo:
1291 commit = repo[commit_id]
1291 commit = repo[commit_id]
1292
1292
1293 if path in ['', '/']:
1293 if path in ['', '/']:
1294 tree = commit.tree
1294 tree = commit.tree
1295 else:
1295 else:
1296 tree = commit.tree[path.rstrip('/')]
1296 tree = commit.tree[path.rstrip('/')]
1297 tree_id = tree.id.hex
1297 tree_id = tree.id.hex
1298 try:
1298 try:
1299 tree = repo[tree_id]
1299 tree = repo[tree_id]
1300 except KeyError:
1300 except KeyError:
1301 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1301 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1302
1302
1303 index = LibGit2Index.Index()
1303 index = LibGit2Index.Index()
1304 index.read_tree(tree)
1304 index.read_tree(tree)
1305 file_iter = index
1305 file_iter = index
1306
1306
1307 for fn in file_iter:
1307 for fn in file_iter:
1308 file_path = fn.path
1308 file_path = fn.path
1309 mode = fn.mode
1309 mode = fn.mode
1310 is_link = stat.S_ISLNK(mode)
1310 is_link = stat.S_ISLNK(mode)
1311 if mode == pygit2.GIT_FILEMODE_COMMIT:
1311 if mode == pygit2.GIT_FILEMODE_COMMIT:
1312 log.debug('Skipping path %s as a commit node', file_path)
1312 log.debug('Skipping path %s as a commit node', file_path)
1313 continue
1313 continue
1314 yield ArchiveNode(file_path, mode, is_link, repo[fn.hex].read_raw)
1314 yield ArchiveNode(file_path, mode, is_link, repo[fn.hex].read_raw)
1315
1315
1316 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
1316 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
1317 archive_dir_name, commit_id)
1317 archive_dir_name, commit_id)
@@ -1,155 +1,155 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import io
18 import io
19 import os
19 import os
20 import sys
20 import sys
21
21
22 import pytest
22 import pytest
23
23
24 from vcsserver import subprocessio
24 from vcsserver import subprocessio
25 from vcsserver.str_utils import ascii_bytes
25 from vcsserver.str_utils import ascii_bytes
26
26
27
27
28 class FileLikeObj(object): # pragma: no cover
28 class FileLikeObj(object): # pragma: no cover
29
29
30 def __init__(self, data: bytes, size):
30 def __init__(self, data: bytes, size):
31 chunks = size // len(data)
31 chunks = size // len(data)
32
32
33 self.stream = self._get_stream(data, chunks)
33 self.stream = self._get_stream(data, chunks)
34
34
35 def _get_stream(self, data, chunks):
35 def _get_stream(self, data, chunks):
36 for x in range(chunks):
36 for x in range(chunks):
37 yield data
37 yield data
38
38
39 def read(self, n):
39 def read(self, n):
40
40
41 buffer_stream = b''
41 buffer_stream = b''
42 for chunk in self.stream:
42 for chunk in self.stream:
43 buffer_stream += chunk
43 buffer_stream += chunk
44 if len(buffer_stream) >= n:
44 if len(buffer_stream) >= n:
45 break
45 break
46
46
47 # self.stream = self.bytes[n:]
47 # self.stream = self.bytes[n:]
48 return buffer_stream
48 return buffer_stream
49
49
50
50
51 @pytest.fixture(scope='module')
51 @pytest.fixture(scope='module')
52 def environ():
52 def environ():
53 """Delete coverage variables, as they make the tests fail."""
53 """Delete coverage variables, as they make the tests fail."""
54 env = dict(os.environ)
54 env = dict(os.environ)
55 for key in env.keys():
55 for key in list(env.keys()):
56 if key.startswith('COV_CORE_'):
56 if key.startswith('COV_CORE_'):
57 del env[key]
57 del env[key]
58
58
59 return env
59 return env
60
60
61
61
62 def _get_python_args(script):
62 def _get_python_args(script):
63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
64
64
65
65
66 def test_raise_exception_on_non_zero_return_code(environ):
66 def test_raise_exception_on_non_zero_return_code(environ):
67 call_args = _get_python_args('raise ValueError("fail")')
67 call_args = _get_python_args('raise ValueError("fail")')
68 with pytest.raises(OSError):
68 with pytest.raises(OSError):
69 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
69 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
70
70
71
71
72 def test_does_not_fail_on_non_zero_return_code(environ):
72 def test_does_not_fail_on_non_zero_return_code(environ):
73 call_args = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
73 call_args = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
74 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
74 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
75 output = b''.join(proc)
75 output = b''.join(proc)
76
76
77 assert output == b'hello'
77 assert output == b'hello'
78
78
79
79
80 def test_raise_exception_on_stderr(environ):
80 def test_raise_exception_on_stderr(environ):
81 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')
81 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')
82
82
83 with pytest.raises(OSError) as excinfo:
83 with pytest.raises(OSError) as excinfo:
84 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
84 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
85
85
86 assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
86 assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
87
87
88
88
89 def test_does_not_fail_on_stderr(environ):
89 def test_does_not_fail_on_stderr(environ):
90 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush; time.sleep(2);')
90 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush; time.sleep(2);')
91 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_stderr=False, env=environ)
91 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_stderr=False, env=environ)
92 output = b''.join(proc)
92 output = b''.join(proc)
93
93
94 assert output == b''
94 assert output == b''
95
95
96
96
97 @pytest.mark.parametrize('size', [
97 @pytest.mark.parametrize('size', [
98 1,
98 1,
99 10 ** 5
99 10 ** 5
100 ])
100 ])
101 def test_output_with_no_input(size, environ):
101 def test_output_with_no_input(size, environ):
102 call_args = _get_python_args(f'sys.stdout.write("X" * {size});')
102 call_args = _get_python_args(f'sys.stdout.write("X" * {size});')
103 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ)
103 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ)
104 output = b''.join(proc)
104 output = b''.join(proc)
105
105
106 assert output == ascii_bytes("X" * size)
106 assert output == ascii_bytes("X" * size)
107
107
108
108
109 @pytest.mark.parametrize('size', [
109 @pytest.mark.parametrize('size', [
110 1,
110 1,
111 10 ** 5
111 10 ** 5
112 ])
112 ])
113 def test_output_with_no_input_does_not_fail(size, environ):
113 def test_output_with_no_input_does_not_fail(size, environ):
114
114
115 call_args = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
115 call_args = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
116 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
116 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
117 output = b''.join(proc)
117 output = b''.join(proc)
118
118
119 assert output == ascii_bytes("X" * size)
119 assert output == ascii_bytes("X" * size)
120
120
121
121
122 @pytest.mark.parametrize('size', [
122 @pytest.mark.parametrize('size', [
123 1,
123 1,
124 10 ** 5
124 10 ** 5
125 ])
125 ])
126 def test_output_with_input(size, environ):
126 def test_output_with_input(size, environ):
127 data_len = size
127 data_len = size
128 inputstream = FileLikeObj(b'X', size)
128 inputstream = FileLikeObj(b'X', size)
129
129
130 # This acts like the cat command.
130 # This acts like the cat command.
131 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
131 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 # note: in this tests we explicitly don't assign chunker to a variable and let it stream directly
132 # note: in this tests we explicitly don't assign chunker to a variable and let it stream directly
133 output = b''.join(
133 output = b''.join(
134 subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
134 subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
135 )
135 )
136
136
137 assert len(output) == data_len
137 assert len(output) == data_len
138
138
139
139
140 @pytest.mark.parametrize('size', [
140 @pytest.mark.parametrize('size', [
141 1,
141 1,
142 10 ** 5
142 10 ** 5
143 ])
143 ])
144 def test_output_with_input_skipping_iterator(size, environ):
144 def test_output_with_input_skipping_iterator(size, environ):
145 data_len = size
145 data_len = size
146 inputstream = FileLikeObj(b'X', size)
146 inputstream = FileLikeObj(b'X', size)
147
147
148 # This acts like the cat command.
148 # This acts like the cat command.
149 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
149 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
150
150
151 # Note: assigning the chunker makes sure that it is not deleted too early
151 # Note: assigning the chunker makes sure that it is not deleted too early
152 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
152 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
153 output = b''.join(proc.stdout)
153 output = b''.join(proc.stdout)
154
154
155 assert len(output) == data_len
155 assert len(output) == data_len
@@ -1,98 +1,98 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import wsgiref.simple_server
18 import wsgiref.simple_server
19 import wsgiref.validate
19 import wsgiref.validate
20
20
21 from vcsserver import wsgi_app_caller
21 from vcsserver import wsgi_app_caller
22 from vcsserver.str_utils import ascii_bytes, safe_str
22 from vcsserver.str_utils import ascii_bytes, safe_str
23
23
24
24
25 @wsgiref.validate.validator
25 @wsgiref.validate.validator
26 def demo_app(environ, start_response):
26 def demo_app(environ, start_response):
27 """WSGI app used for testing."""
27 """WSGI app used for testing."""
28
28
29 input_data = safe_str(environ['wsgi.input'].read(1024))
29 input_data = safe_str(environ['wsgi.input'].read(1024))
30
30
31 data = [
31 data = [
32 f'Hello World!\n',
32 'Hello World!\n',
33 f'input_data={input_data}\n',
33 f'input_data={input_data}\n',
34 ]
34 ]
35 for key, value in sorted(environ.items()):
35 for key, value in sorted(environ.items()):
36 data.append(f'{key}={value}\n')
36 data.append(f'{key}={value}\n')
37
37
38 write = start_response("200 OK", [('Content-Type', 'text/plain')])
38 write = start_response("200 OK", [('Content-Type', 'text/plain')])
39 write(b'Old school write method\n')
39 write(b'Old school write method\n')
40 write(b'***********************\n')
40 write(b'***********************\n')
41 return list(map(ascii_bytes, data))
41 return list(map(ascii_bytes, data))
42
42
43
43
44 BASE_ENVIRON = {
44 BASE_ENVIRON = {
45 'REQUEST_METHOD': 'GET',
45 'REQUEST_METHOD': 'GET',
46 'SERVER_NAME': 'localhost',
46 'SERVER_NAME': 'localhost',
47 'SERVER_PORT': '80',
47 'SERVER_PORT': '80',
48 'SCRIPT_NAME': '',
48 'SCRIPT_NAME': '',
49 'PATH_INFO': '/',
49 'PATH_INFO': '/',
50 'QUERY_STRING': '',
50 'QUERY_STRING': '',
51 'foo.var': 'bla',
51 'foo.var': 'bla',
52 }
52 }
53
53
54
54
55 def test_complete_environ():
55 def test_complete_environ():
56 environ = dict(BASE_ENVIRON)
56 environ = dict(BASE_ENVIRON)
57 data = b"data"
57 data = b"data"
58 wsgi_app_caller._complete_environ(environ, data)
58 wsgi_app_caller._complete_environ(environ, data)
59 wsgiref.validate.check_environ(environ)
59 wsgiref.validate.check_environ(environ)
60
60
61 assert data == environ['wsgi.input'].read(1024)
61 assert data == environ['wsgi.input'].read(1024)
62
62
63
63
64 def test_start_response():
64 def test_start_response():
65 start_response = wsgi_app_caller._StartResponse()
65 start_response = wsgi_app_caller._StartResponse()
66 status = '200 OK'
66 status = '200 OK'
67 headers = [('Content-Type', 'text/plain')]
67 headers = [('Content-Type', 'text/plain')]
68 start_response(status, headers)
68 start_response(status, headers)
69
69
70 assert status == start_response.status
70 assert status == start_response.status
71 assert headers == start_response.headers
71 assert headers == start_response.headers
72
72
73
73
74 def test_start_response_with_error():
74 def test_start_response_with_error():
75 start_response = wsgi_app_caller._StartResponse()
75 start_response = wsgi_app_caller._StartResponse()
76 status = '500 Internal Server Error'
76 status = '500 Internal Server Error'
77 headers = [('Content-Type', 'text/plain')]
77 headers = [('Content-Type', 'text/plain')]
78 start_response(status, headers, (None, None, None))
78 start_response(status, headers, (None, None, None))
79
79
80 assert status == start_response.status
80 assert status == start_response.status
81 assert headers == start_response.headers
81 assert headers == start_response.headers
82
82
83
83
84 def test_wsgi_app_caller():
84 def test_wsgi_app_caller():
85 environ = dict(BASE_ENVIRON)
85 environ = dict(BASE_ENVIRON)
86 input_data = 'some text'
86 input_data = 'some text'
87
87
88 caller = wsgi_app_caller.WSGIAppCaller(demo_app)
88 caller = wsgi_app_caller.WSGIAppCaller(demo_app)
89 responses, status, headers = caller.handle(environ, input_data)
89 responses, status, headers = caller.handle(environ, input_data)
90 response = b''.join(responses)
90 response = b''.join(responses)
91
91
92 assert status == '200 OK'
92 assert status == '200 OK'
93 assert headers == [('Content-Type', 'text/plain')]
93 assert headers == [('Content-Type', 'text/plain')]
94 assert response.startswith(b'Old school write method\n***********************\n')
94 assert response.startswith(b'Old school write method\n***********************\n')
95 assert b'Hello World!\n' in response
95 assert b'Hello World!\n' in response
96 assert b'foo.var=bla\n' in response
96 assert b'foo.var=bla\n' in response
97
97
98 assert ascii_bytes(f'input_data={input_data}\n') in response
98 assert ascii_bytes(f'input_data={input_data}\n') in response
General Comments 0
You need to be logged in to leave comments. Login now