##// END OF EJS Templates
fix for api key lookup, reuse same function in user model
marcink -
r1693:60249224 beta
parent child Browse files
Show More
@@ -1,248 +1,250 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.controllers.api
4 4 ~~~~~~~~~~~~~~~~~~~~~~~~~
5 5
6 6 JSON RPC controller
7 7
8 8 :created_on: Aug 20, 2011
9 9 :author: marcink
10 10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software; you can redistribute it and/or
14 14 # modify it under the terms of the GNU General Public License
15 15 # as published by the Free Software Foundation; version 2
16 16 # of the License or (at your opinion) any later version of the license.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program; if not, write to the Free Software
25 25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 26 # MA 02110-1301, USA.
27 27
28 28 import inspect
29 29 import logging
30 30 import types
31 31 import urllib
32 32 import traceback
33 33
34 34 from rhodecode.lib.compat import izip_longest, json
35 35
36 36 from paste.response import replace_header
37 37
38 38 from pylons.controllers import WSGIController
39 39
40 40
41 41 from webob.exc import HTTPNotFound, HTTPForbidden, HTTPInternalServerError, \
42 42 HTTPBadRequest, HTTPError
43 43
44 44 from rhodecode.model.db import User
45 45 from rhodecode.lib.auth import AuthUser
46 46
47 47 log = logging.getLogger('JSONRPC')
48 48
49 49 class JSONRPCError(BaseException):
50 50
51 51 def __init__(self, message):
52 52 self.message = message
53 53
54 54 def __str__(self):
55 55 return str(self.message)
56 56
57 57
58 58 def jsonrpc_error(message, code=None):
59 59 """
60 60 Generate a Response object with a JSON-RPC error body
61 61 """
62 62 from pylons.controllers.util import Response
63 63 resp = Response(body=json.dumps(dict(result=None, error=message)),
64 64 status=code,
65 65 content_type='application/json')
66 66 return resp
67 67
68 68
69 69
70 70 class JSONRPCController(WSGIController):
71 71 """
72 72 A WSGI-speaking JSON-RPC controller class
73 73
74 74 See the specification:
75 75 <http://json-rpc.org/wiki/specification>`.
76 76
77 77 Valid controller return values should be json-serializable objects.
78 78
79 79 Sub-classes should catch their exceptions and raise JSONRPCError
80 80 if they want to pass meaningful errors to the client.
81 81
82 82 """
83 83
84 84 def _get_method_args(self):
85 85 """
86 86 Return `self._rpc_args` to dispatched controller method
87 87 chosen by __call__
88 88 """
89 89 return self._rpc_args
90 90
91 91 def __call__(self, environ, start_response):
92 92 """
93 93 Parse the request body as JSON, look up the method on the
94 94 controller and if it exists, dispatch to it.
95 95 """
96 96 if 'CONTENT_LENGTH' not in environ:
97 97 log.debug("No Content-Length")
98 98 return jsonrpc_error(message="No Content-Length in request")
99 99 else:
100 100 length = environ['CONTENT_LENGTH'] or 0
101 101 length = int(environ['CONTENT_LENGTH'])
102 102 log.debug('Content-Length: %s', length)
103 103
104 104 if length == 0:
105 105 log.debug("Content-Length is 0")
106 106 return jsonrpc_error(message="Content-Length is 0")
107 107
108 108 raw_body = environ['wsgi.input'].read(length)
109 109
110 110 try:
111 111 json_body = json.loads(urllib.unquote_plus(raw_body))
112 112 except ValueError, e:
113 113 #catch JSON errors Here
114 114 return jsonrpc_error(message="JSON parse error ERR:%s RAW:%r" \
115 115 % (e, urllib.unquote_plus(raw_body)))
116 116
117 #check AUTH based on API KEY
117 # check AUTH based on API KEY
118 118 try:
119 119 self._req_api_key = json_body['api_key']
120 120 self._req_method = json_body['method']
121 121 self._req_params = json_body['args']
122 122 log.debug('method: %s, params: %s',
123 123 self._req_method,
124 124 self._req_params)
125 125 except KeyError, e:
126 126 return jsonrpc_error(message='Incorrect JSON query missing %s' % e)
127 127
128 #check if we can find this session using api_key
128 # check if we can find this session using api_key
129 129 try:
130 130 u = User.get_by_api_key(self._req_api_key)
131 if u is None:
132 return jsonrpc_error(message='Invalid API KEY')
131 133 auth_u = AuthUser(u.user_id, self._req_api_key)
132 134 except Exception, e:
133 135 return jsonrpc_error(message='Invalid API KEY')
134 136
135 137 self._error = None
136 138 try:
137 139 self._func = self._find_method()
138 140 except AttributeError, e:
139 141 return jsonrpc_error(message=str(e))
140 142
141 143 # now that we have a method, add self._req_params to
142 144 # self.kargs and dispatch control to WGIController
143 145 argspec = inspect.getargspec(self._func)
144 146 arglist = argspec[0][1:]
145 147 defaults = argspec[3] or []
146 148 default_empty = types.NotImplementedType
147 149
148 150 kwarglist = list(izip_longest(reversed(arglist), reversed(defaults),
149 151 fillvalue=default_empty))
150 152
151 153 # this is little trick to inject logged in user for
152 154 # perms decorators to work they expect the controller class to have
153 155 # rhodecode_user attribute set
154 156 self.rhodecode_user = auth_u
155 157
156 158 # This attribute will need to be first param of a method that uses
157 159 # api_key, which is translated to instance of user at that name
158 160 USER_SESSION_ATTR = 'apiuser'
159 161
160 162 if USER_SESSION_ATTR not in arglist:
161 163 return jsonrpc_error(message='This method [%s] does not support '
162 164 'authentication (missing %s param)' %
163 165 (self._func.__name__, USER_SESSION_ATTR))
164 166
165 167 # get our arglist and check if we provided them as args
166 168 for arg, default in kwarglist:
167 169 if arg == USER_SESSION_ATTR:
168 170 # USER_SESSION_ATTR is something translated from api key and
169 171 # this is checked before so we don't need validate it
170 172 continue
171 173
172 174 # skip the required param check if it's default value is
173 175 # NotImplementedType (default_empty)
174 176 if not self._req_params or (type(default) == default_empty
175 177 and arg not in self._req_params):
176 178 return jsonrpc_error(message=('Missing non optional %s arg '
177 179 'in JSON DATA') % arg)
178 180
179 181 self._rpc_args = {USER_SESSION_ATTR:u}
180 182 self._rpc_args.update(self._req_params)
181 183
182 184 self._rpc_args['action'] = self._req_method
183 185 self._rpc_args['environ'] = environ
184 186 self._rpc_args['start_response'] = start_response
185 187
186 188 status = []
187 189 headers = []
188 190 exc_info = []
189 191 def change_content(new_status, new_headers, new_exc_info=None):
190 192 status.append(new_status)
191 193 headers.extend(new_headers)
192 194 exc_info.append(new_exc_info)
193 195
194 196 output = WSGIController.__call__(self, environ, change_content)
195 197 output = list(output)
196 198 headers.append(('Content-Length', str(len(output[0]))))
197 199 replace_header(headers, 'Content-Type', 'application/json')
198 200 start_response(status[0], headers, exc_info[0])
199 201
200 202 return output
201 203
202 204 def _dispatch_call(self):
203 205 """
204 206 Implement dispatch interface specified by WSGIController
205 207 """
206 208 try:
207 209 raw_response = self._inspect_call(self._func)
208 210 if isinstance(raw_response, HTTPError):
209 211 self._error = str(raw_response)
210 212 except JSONRPCError, e:
211 213 self._error = str(e)
212 214 except Exception, e:
213 215 log.error('Encountered unhandled exception: %s' \
214 216 % traceback.format_exc())
215 217 json_exc = JSONRPCError('Internal server error')
216 218 self._error = str(json_exc)
217 219
218 220 if self._error is not None:
219 221 raw_response = None
220 222
221 223 response = dict(result=raw_response, error=self._error)
222 224
223 225 try:
224 226 return json.dumps(response)
225 227 except TypeError, e:
226 228 log.debug('Error encoding response: %s', e)
227 229 return json.dumps(dict(result=None,
228 230 error="Error encoding response"))
229 231
230 232 def _find_method(self):
231 233 """
232 234 Return method named by `self._req_method` in controller if able
233 235 """
234 236 log.debug('Trying to find JSON-RPC method: %s', self._req_method)
235 237 if self._req_method.startswith('_'):
236 238 raise AttributeError("Method not allowed")
237 239
238 240 try:
239 241 func = getattr(self, self._req_method, None)
240 242 except UnicodeEncodeError:
241 243 raise AttributeError("Problem decoding unicode in requested "
242 244 "method name.")
243 245
244 246 if isinstance(func, types.MethodType):
245 247 return func
246 248 else:
247 249 raise AttributeError("No such method: %s" % self._req_method)
248 250
@@ -1,1120 +1,1120 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.model.db
4 4 ~~~~~~~~~~~~~~~~~~
5 5
6 6 Database Models for RhodeCode
7 7
8 8 :created_on: Apr 08, 2010
9 9 :author: marcink
10 10 :copyright: (C) 2009-2011 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software: you can redistribute it and/or modify
14 14 # it under the terms of the GNU General Public License as published by
15 15 # the Free Software Foundation, either version 3 of the License, or
16 16 # (at your option) any later version.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program. If not, see <http://www.gnu.org/licenses/>.
25 25
26 26 import os
27 27 import logging
28 28 import datetime
29 29 import traceback
30 30 from datetime import date
31 31
32 32 from sqlalchemy import *
33 33 from sqlalchemy.exc import DatabaseError
34 34 from sqlalchemy.ext.hybrid import hybrid_property
35 35 from sqlalchemy.orm import relationship, joinedload, class_mapper, validates
36 36 from beaker.cache import cache_region, region_invalidate
37 37
38 38 from vcs import get_backend
39 39 from vcs.utils.helpers import get_scm
40 40 from vcs.exceptions import VCSError
41 41 from vcs.utils.lazy import LazyProperty
42 42
43 43 from rhodecode.lib import str2bool, safe_str, get_changeset_safe, \
44 44 generate_api_key, safe_unicode
45 45 from rhodecode.lib.exceptions import UsersGroupsAssignedException
46 46 from rhodecode.lib.compat import json
47 47 from rhodecode.lib.caching_query import FromCache
48 48
49 49 from rhodecode.model.meta import Base, Session
50 50
51 51
52 52
53 53 log = logging.getLogger(__name__)
54 54
55 55 #==============================================================================
56 56 # BASE CLASSES
57 57 #==============================================================================
58 58
59 59 class ModelSerializer(json.JSONEncoder):
60 60 """
61 61 Simple Serializer for JSON,
62 62
63 63 usage::
64 64
65 65 to make object customized for serialization implement a __json__
66 66 method that will return a dict for serialization into json
67 67
68 68 example::
69 69
70 70 class Task(object):
71 71
72 72 def __init__(self, name, value):
73 73 self.name = name
74 74 self.value = value
75 75
76 76 def __json__(self):
77 77 return dict(name=self.name,
78 78 value=self.value)
79 79
80 80 """
81 81
82 82 def default(self, obj):
83 83
84 84 if hasattr(obj, '__json__'):
85 85 return obj.__json__()
86 86 else:
87 87 return json.JSONEncoder.default(self, obj)
88 88
89 89 class BaseModel(object):
90 90 """Base Model for all classess
91 91
92 92 """
93 93
94 94 @classmethod
95 95 def _get_keys(cls):
96 96 """return column names for this model """
97 97 return class_mapper(cls).c.keys()
98 98
99 99 def get_dict(self):
100 100 """return dict with keys and values corresponding
101 101 to this model data """
102 102
103 103 d = {}
104 104 for k in self._get_keys():
105 105 d[k] = getattr(self, k)
106 106 return d
107 107
108 108 def get_appstruct(self):
109 109 """return list with keys and values tupples corresponding
110 110 to this model data """
111 111
112 112 l = []
113 113 for k in self._get_keys():
114 114 l.append((k, getattr(self, k),))
115 115 return l
116 116
117 117 def populate_obj(self, populate_dict):
118 118 """populate model with data from given populate_dict"""
119 119
120 120 for k in self._get_keys():
121 121 if k in populate_dict:
122 122 setattr(self, k, populate_dict[k])
123 123
124 124 @classmethod
125 125 def query(cls):
126 126 return Session.query(cls)
127 127
128 128 @classmethod
129 129 def get(cls, id_):
130 130 if id_:
131 131 return cls.query().get(id_)
132 132
133 133 @classmethod
134 134 def getAll(cls):
135 135 return cls.query().all()
136 136
137 137 @classmethod
138 138 def delete(cls, id_):
139 139 obj = cls.query().get(id_)
140 140 Session.delete(obj)
141 141 Session.commit()
142 142
143 143
144 144 class RhodeCodeSetting(Base, BaseModel):
145 145 __tablename__ = 'rhodecode_settings'
146 146 __table_args__ = (UniqueConstraint('app_settings_name'), {'extend_existing':True})
147 147 app_settings_id = Column("app_settings_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
148 148 app_settings_name = Column("app_settings_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
149 149 _app_settings_value = Column("app_settings_value", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
150 150
151 151 def __init__(self, k='', v=''):
152 152 self.app_settings_name = k
153 153 self.app_settings_value = v
154 154
155 155
156 156 @validates('_app_settings_value')
157 157 def validate_settings_value(self, key, val):
158 158 assert type(val) == unicode
159 159 return val
160 160
161 161 @hybrid_property
162 162 def app_settings_value(self):
163 163 v = self._app_settings_value
164 164 if v == 'ldap_active':
165 165 v = str2bool(v)
166 166 return v
167 167
168 168 @app_settings_value.setter
169 169 def app_settings_value(self, val):
170 170 """
171 171 Setter that will always make sure we use unicode in app_settings_value
172 172
173 173 :param val:
174 174 """
175 175 self._app_settings_value = safe_unicode(val)
176 176
177 177 def __repr__(self):
178 178 return "<%s('%s:%s')>" % (self.__class__.__name__,
179 179 self.app_settings_name, self.app_settings_value)
180 180
181 181
182 182 @classmethod
183 183 def get_by_name(cls, ldap_key):
184 184 return cls.query()\
185 185 .filter(cls.app_settings_name == ldap_key).scalar()
186 186
187 187 @classmethod
188 188 def get_app_settings(cls, cache=False):
189 189
190 190 ret = cls.query()
191 191
192 192 if cache:
193 193 ret = ret.options(FromCache("sql_cache_short", "get_hg_settings"))
194 194
195 195 if not ret:
196 196 raise Exception('Could not get application settings !')
197 197 settings = {}
198 198 for each in ret:
199 199 settings['rhodecode_' + each.app_settings_name] = \
200 200 each.app_settings_value
201 201
202 202 return settings
203 203
204 204 @classmethod
205 205 def get_ldap_settings(cls, cache=False):
206 206 ret = cls.query()\
207 207 .filter(cls.app_settings_name.startswith('ldap_')).all()
208 208 fd = {}
209 209 for row in ret:
210 210 fd.update({row.app_settings_name:row.app_settings_value})
211 211
212 212 return fd
213 213
214 214
215 215 class RhodeCodeUi(Base, BaseModel):
216 216 __tablename__ = 'rhodecode_ui'
217 217 __table_args__ = (UniqueConstraint('ui_key'), {'extend_existing':True})
218 218
219 219 HOOK_UPDATE = 'changegroup.update'
220 220 HOOK_REPO_SIZE = 'changegroup.repo_size'
221 221 HOOK_PUSH = 'pretxnchangegroup.push_logger'
222 222 HOOK_PULL = 'preoutgoing.pull_logger'
223 223
224 224 ui_id = Column("ui_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
225 225 ui_section = Column("ui_section", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
226 226 ui_key = Column("ui_key", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
227 227 ui_value = Column("ui_value", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
228 228 ui_active = Column("ui_active", Boolean(), nullable=True, unique=None, default=True)
229 229
230 230
231 231 @classmethod
232 232 def get_by_key(cls, key):
233 233 return cls.query().filter(cls.ui_key == key)
234 234
235 235
236 236 @classmethod
237 237 def get_builtin_hooks(cls):
238 238 q = cls.query()
239 239 q = q.filter(cls.ui_key.in_([cls.HOOK_UPDATE,
240 240 cls.HOOK_REPO_SIZE,
241 241 cls.HOOK_PUSH, cls.HOOK_PULL]))
242 242 return q.all()
243 243
244 244 @classmethod
245 245 def get_custom_hooks(cls):
246 246 q = cls.query()
247 247 q = q.filter(~cls.ui_key.in_([cls.HOOK_UPDATE,
248 248 cls.HOOK_REPO_SIZE,
249 249 cls.HOOK_PUSH, cls.HOOK_PULL]))
250 250 q = q.filter(cls.ui_section == 'hooks')
251 251 return q.all()
252 252
253 253 @classmethod
254 254 def create_or_update_hook(cls, key, val):
255 255 new_ui = cls.get_by_key(key).scalar() or cls()
256 256 new_ui.ui_section = 'hooks'
257 257 new_ui.ui_active = True
258 258 new_ui.ui_key = key
259 259 new_ui.ui_value = val
260 260
261 261 Session.add(new_ui)
262 262 Session.commit()
263 263
264 264
265 265 class User(Base, BaseModel):
266 266 __tablename__ = 'users'
267 267 __table_args__ = (UniqueConstraint('username'), UniqueConstraint('email'), {'extend_existing':True})
268 268 user_id = Column("user_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
269 269 username = Column("username", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
270 270 password = Column("password", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
271 271 active = Column("active", Boolean(), nullable=True, unique=None, default=None)
272 272 admin = Column("admin", Boolean(), nullable=True, unique=None, default=False)
273 273 name = Column("name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
274 274 lastname = Column("lastname", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
275 275 email = Column("email", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
276 276 last_login = Column("last_login", DateTime(timezone=False), nullable=True, unique=None, default=None)
277 277 ldap_dn = Column("ldap_dn", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
278 278 api_key = Column("api_key", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
279 279
280 280 user_log = relationship('UserLog', cascade='all')
281 281 user_perms = relationship('UserToPerm', primaryjoin="User.user_id==UserToPerm.user_id", cascade='all')
282 282
283 283 repositories = relationship('Repository')
284 284 user_followers = relationship('UserFollowing', primaryjoin='UserFollowing.follows_user_id==User.user_id', cascade='all')
285 285 repo_to_perm = relationship('UserRepoToPerm', primaryjoin='UserRepoToPerm.user_id==User.user_id', cascade='all')
286 286
287 287 group_member = relationship('UsersGroupMember', cascade='all')
288 288
289 289 @property
290 290 def full_contact(self):
291 291 return '%s %s <%s>' % (self.name, self.lastname, self.email)
292 292
293 293 @property
294 294 def short_contact(self):
295 295 return '%s %s' % (self.name, self.lastname)
296 296
297 297 @property
298 298 def is_admin(self):
299 299 return self.admin
300 300
301 301 def __repr__(self):
302 302 try:
303 303 return "<%s('id:%s:%s')>" % (self.__class__.__name__,
304 304 self.user_id, self.username)
305 305 except:
306 306 return self.__class__.__name__
307 307
308 308 @classmethod
309 309 def get_by_username(cls, username, case_insensitive=False, cache=False):
310 310 if case_insensitive:
311 311 q = cls.query().filter(cls.username.ilike(username))
312 312 else:
313 313 q = cls.query().filter(cls.username == username)
314 314
315 315 if cache:
316 316 q = q.options(FromCache("sql_cache_short",
317 317 "get_user_%s" % username))
318 318 return q.scalar()
319 319
320 320 @classmethod
321 321 def get_by_api_key(cls, api_key, cache=False):
322 322 q = cls.query().filter(cls.api_key == api_key)
323 323
324 324 if cache:
325 325 q = q.options(FromCache("sql_cache_short",
326 326 "get_api_key_%s" % api_key))
327 q.one()
327 return q.scalar()
328 328
329 329 def update_lastlogin(self):
330 330 """Update user lastlogin"""
331 331
332 332 self.last_login = datetime.datetime.now()
333 333 Session.add(self)
334 334 Session.commit()
335 335 log.debug('updated user %s lastlogin', self.username)
336 336
337 337 class UserLog(Base, BaseModel):
338 338 __tablename__ = 'user_logs'
339 339 __table_args__ = {'extend_existing':True}
340 340 user_log_id = Column("user_log_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
341 341 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
342 342 repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
343 343 repository_name = Column("repository_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
344 344 user_ip = Column("user_ip", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
345 345 action = Column("action", UnicodeText(length=1200000, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
346 346 action_date = Column("action_date", DateTime(timezone=False), nullable=True, unique=None, default=None)
347 347
348 348 @property
349 349 def action_as_day(self):
350 350 return date(*self.action_date.timetuple()[:3])
351 351
352 352 user = relationship('User')
353 353 repository = relationship('Repository')
354 354
355 355
356 356 class UsersGroup(Base, BaseModel):
357 357 __tablename__ = 'users_groups'
358 358 __table_args__ = {'extend_existing':True}
359 359
360 360 users_group_id = Column("users_group_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
361 361 users_group_name = Column("users_group_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=False, unique=True, default=None)
362 362 users_group_active = Column("users_group_active", Boolean(), nullable=True, unique=None, default=None)
363 363
364 364 members = relationship('UsersGroupMember', cascade="all, delete, delete-orphan", lazy="joined")
365 365
366 366 def __repr__(self):
367 367 return '<userGroup(%s)>' % (self.users_group_name)
368 368
369 369 @classmethod
370 370 def get_by_group_name(cls, group_name, cache=False, case_insensitive=False):
371 371 if case_insensitive:
372 372 gr = cls.query()\
373 373 .filter(cls.users_group_name.ilike(group_name))
374 374 else:
375 375 gr = cls.query()\
376 376 .filter(cls.users_group_name == group_name)
377 377 if cache:
378 378 gr = gr.options(FromCache("sql_cache_short",
379 379 "get_user_%s" % group_name))
380 380 return gr.scalar()
381 381
382 382
383 383 @classmethod
384 384 def get(cls, users_group_id, cache=False):
385 385 users_group = cls.query()
386 386 if cache:
387 387 users_group = users_group.options(FromCache("sql_cache_short",
388 388 "get_users_group_%s" % users_group_id))
389 389 return users_group.get(users_group_id)
390 390
391 391 @classmethod
392 392 def create(cls, form_data):
393 393 try:
394 394 new_users_group = cls()
395 395 for k, v in form_data.items():
396 396 setattr(new_users_group, k, v)
397 397
398 398 Session.add(new_users_group)
399 399 Session.commit()
400 400 return new_users_group
401 401 except:
402 402 log.error(traceback.format_exc())
403 403 Session.rollback()
404 404 raise
405 405
406 406 @classmethod
407 407 def update(cls, users_group_id, form_data):
408 408
409 409 try:
410 410 users_group = cls.get(users_group_id, cache=False)
411 411
412 412 for k, v in form_data.items():
413 413 if k == 'users_group_members':
414 414 users_group.members = []
415 415 Session.flush()
416 416 members_list = []
417 417 if v:
418 418 v = [v] if isinstance(v, basestring) else v
419 419 for u_id in set(v):
420 420 member = UsersGroupMember(users_group_id, u_id)
421 421 members_list.append(member)
422 422 setattr(users_group, 'members', members_list)
423 423 setattr(users_group, k, v)
424 424
425 425 Session.add(users_group)
426 426 Session.commit()
427 427 except:
428 428 log.error(traceback.format_exc())
429 429 Session.rollback()
430 430 raise
431 431
432 432 @classmethod
433 433 def delete(cls, users_group_id):
434 434 try:
435 435
436 436 # check if this group is not assigned to repo
437 437 assigned_groups = UsersGroupRepoToPerm.query()\
438 438 .filter(UsersGroupRepoToPerm.users_group_id ==
439 439 users_group_id).all()
440 440
441 441 if assigned_groups:
442 442 raise UsersGroupsAssignedException('RepoGroup assigned to %s' %
443 443 assigned_groups)
444 444
445 445 users_group = cls.get(users_group_id, cache=False)
446 446 Session.delete(users_group)
447 447 Session.commit()
448 448 except:
449 449 log.error(traceback.format_exc())
450 450 Session.rollback()
451 451 raise
452 452
453 453 class UsersGroupMember(Base, BaseModel):
454 454 __tablename__ = 'users_groups_members'
455 455 __table_args__ = {'extend_existing':True}
456 456
457 457 users_group_member_id = Column("users_group_member_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
458 458 users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
459 459 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
460 460
461 461 user = relationship('User', lazy='joined')
462 462 users_group = relationship('UsersGroup')
463 463
464 464 def __init__(self, gr_id='', u_id=''):
465 465 self.users_group_id = gr_id
466 466 self.user_id = u_id
467 467
468 468 @staticmethod
469 469 def add_user_to_group(group, user):
470 470 ugm = UsersGroupMember()
471 471 ugm.users_group = group
472 472 ugm.user = user
473 473 Session.add(ugm)
474 474 Session.commit()
475 475 return ugm
476 476
477 477 class Repository(Base, BaseModel):
478 478 __tablename__ = 'repositories'
479 479 __table_args__ = (UniqueConstraint('repo_name'), {'extend_existing':True},)
480 480
481 481 repo_id = Column("repo_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
482 482 repo_name = Column("repo_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=False, unique=True, default=None)
483 483 clone_uri = Column("clone_uri", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=False, default=None)
484 484 repo_type = Column("repo_type", String(length=255, convert_unicode=False, assert_unicode=None), nullable=False, unique=False, default='hg')
485 485 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=False, default=None)
486 486 private = Column("private", Boolean(), nullable=True, unique=None, default=None)
487 487 enable_statistics = Column("statistics", Boolean(), nullable=True, unique=None, default=True)
488 488 enable_downloads = Column("downloads", Boolean(), nullable=True, unique=None, default=True)
489 489 description = Column("description", String(length=10000, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
490 490 created_on = Column('created_on', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
491 491
492 492 fork_id = Column("fork_id", Integer(), ForeignKey('repositories.repo_id'), nullable=True, unique=False, default=None)
493 493 group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=True, unique=False, default=None)
494 494
495 495
496 496 user = relationship('User')
497 497 fork = relationship('Repository', remote_side=repo_id)
498 498 group = relationship('RepoGroup')
499 499 repo_to_perm = relationship('UserRepoToPerm', cascade='all', order_by='UserRepoToPerm.repo_to_perm_id')
500 500 users_group_to_perm = relationship('UsersGroupRepoToPerm', cascade='all')
501 501 stats = relationship('Statistics', cascade='all', uselist=False)
502 502
503 503 followers = relationship('UserFollowing', primaryjoin='UserFollowing.follows_repo_id==Repository.repo_id', cascade='all')
504 504
505 505 logs = relationship('UserLog', cascade='all')
506 506
507 507 def __repr__(self):
508 508 return "<%s('%s:%s')>" % (self.__class__.__name__,
509 509 self.repo_id, self.repo_name)
510 510
511 511 @classmethod
512 512 def url_sep(cls):
513 513 return '/'
514 514
515 515 @classmethod
516 516 def get_by_repo_name(cls, repo_name):
517 517 q = Session.query(cls).filter(cls.repo_name == repo_name)
518 518 q = q.options(joinedload(Repository.fork))\
519 519 .options(joinedload(Repository.user))\
520 520 .options(joinedload(Repository.group))
521 521 return q.one()
522 522
523 523 @classmethod
524 524 def get_repo_forks(cls, repo_id):
525 525 return cls.query().filter(Repository.fork_id == repo_id)
526 526
527 527 @classmethod
528 528 def base_path(cls):
529 529 """
530 530 Returns base path when all repos are stored
531 531
532 532 :param cls:
533 533 """
534 534 q = Session.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key ==
535 535 cls.url_sep())
536 536 q.options(FromCache("sql_cache_short", "repository_repo_path"))
537 537 return q.one().ui_value
538 538
539 539 @property
540 540 def just_name(self):
541 541 return self.repo_name.split(Repository.url_sep())[-1]
542 542
543 543 @property
544 544 def groups_with_parents(self):
545 545 groups = []
546 546 if self.group is None:
547 547 return groups
548 548
549 549 cur_gr = self.group
550 550 groups.insert(0, cur_gr)
551 551 while 1:
552 552 gr = getattr(cur_gr, 'parent_group', None)
553 553 cur_gr = cur_gr.parent_group
554 554 if gr is None:
555 555 break
556 556 groups.insert(0, gr)
557 557
558 558 return groups
559 559
560 560 @property
561 561 def groups_and_repo(self):
562 562 return self.groups_with_parents, self.just_name
563 563
564 564 @LazyProperty
565 565 def repo_path(self):
566 566 """
567 567 Returns base full path for that repository means where it actually
568 568 exists on a filesystem
569 569 """
570 570 q = Session.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key ==
571 571 Repository.url_sep())
572 572 q.options(FromCache("sql_cache_short", "repository_repo_path"))
573 573 return q.one().ui_value
574 574
575 575 @property
576 576 def repo_full_path(self):
577 577 p = [self.repo_path]
578 578 # we need to split the name by / since this is how we store the
579 579 # names in the database, but that eventually needs to be converted
580 580 # into a valid system path
581 581 p += self.repo_name.split(Repository.url_sep())
582 582 return os.path.join(*p)
583 583
584 584 def get_new_name(self, repo_name):
585 585 """
586 586 returns new full repository name based on assigned group and new new
587 587
588 588 :param group_name:
589 589 """
590 590 path_prefix = self.group.full_path_splitted if self.group else []
591 591 return Repository.url_sep().join(path_prefix + [repo_name])
592 592
593 593 @property
594 594 def _ui(self):
595 595 """
596 596 Creates an db based ui object for this repository
597 597 """
598 598 from mercurial import ui
599 599 from mercurial import config
600 600 baseui = ui.ui()
601 601
602 602 #clean the baseui object
603 603 baseui._ocfg = config.config()
604 604 baseui._ucfg = config.config()
605 605 baseui._tcfg = config.config()
606 606
607 607
608 608 ret = RhodeCodeUi.query()\
609 609 .options(FromCache("sql_cache_short", "repository_repo_ui")).all()
610 610
611 611 hg_ui = ret
612 612 for ui_ in hg_ui:
613 613 if ui_.ui_active:
614 614 log.debug('settings ui from db[%s]%s:%s', ui_.ui_section,
615 615 ui_.ui_key, ui_.ui_value)
616 616 baseui.setconfig(ui_.ui_section, ui_.ui_key, ui_.ui_value)
617 617
618 618 return baseui
619 619
620 620 @classmethod
621 621 def is_valid(cls, repo_name):
622 622 """
623 623 returns True if given repo name is a valid filesystem repository
624 624
625 625 @param cls:
626 626 @param repo_name:
627 627 """
628 628 from rhodecode.lib.utils import is_valid_repo
629 629
630 630 return is_valid_repo(repo_name, cls.base_path())
631 631
632 632
633 633 #==========================================================================
634 634 # SCM PROPERTIES
635 635 #==========================================================================
636 636
637 637 def get_changeset(self, rev):
638 638 return get_changeset_safe(self.scm_instance, rev)
639 639
640 640 @property
641 641 def tip(self):
642 642 return self.get_changeset('tip')
643 643
644 644 @property
645 645 def author(self):
646 646 return self.tip.author
647 647
648 648 @property
649 649 def last_change(self):
650 650 return self.scm_instance.last_change
651 651
652 652 #==========================================================================
653 653 # SCM CACHE INSTANCE
654 654 #==========================================================================
655 655
656 656 @property
657 657 def invalidate(self):
658 658 return CacheInvalidation.invalidate(self.repo_name)
659 659
660 660 def set_invalidate(self):
661 661 """
662 662 set a cache for invalidation for this instance
663 663 """
664 664 CacheInvalidation.set_invalidate(self.repo_name)
665 665
666 666 @LazyProperty
667 667 def scm_instance(self):
668 668 return self.__get_instance()
669 669
670 670 @property
671 671 def scm_instance_cached(self):
672 672 @cache_region('long_term')
673 673 def _c(repo_name):
674 674 return self.__get_instance()
675 675 rn = self.repo_name
676 676
677 677 inv = self.invalidate
678 678 if inv is not None:
679 679 region_invalidate(_c, None, rn)
680 680 # update our cache
681 681 CacheInvalidation.set_valid(inv.cache_key)
682 682 return _c(rn)
683 683
684 684 def __get_instance(self):
685 685
686 686 repo_full_path = self.repo_full_path
687 687
688 688 try:
689 689 alias = get_scm(repo_full_path)[0]
690 690 log.debug('Creating instance of %s repository', alias)
691 691 backend = get_backend(alias)
692 692 except VCSError:
693 693 log.error(traceback.format_exc())
694 694 log.error('Perhaps this repository is in db and not in '
695 695 'filesystem run rescan repositories with '
696 696 '"destroy old data " option from admin panel')
697 697 return
698 698
699 699 if alias == 'hg':
700 700
701 701 repo = backend(safe_str(repo_full_path), create=False,
702 702 baseui=self._ui)
703 703 # skip hidden web repository
704 704 if repo._get_hidden():
705 705 return
706 706 else:
707 707 repo = backend(repo_full_path, create=False)
708 708
709 709 return repo
710 710
711 711
712 712 class RepoGroup(Base, BaseModel):
713 713 __tablename__ = 'groups'
714 714 __table_args__ = (UniqueConstraint('group_name', 'group_parent_id'),
715 715 CheckConstraint('group_id != group_parent_id'), {'extend_existing':True},)
716 716 __mapper_args__ = {'order_by':'group_name'}
717 717
718 718 group_id = Column("group_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
719 719 group_name = Column("group_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=False, unique=True, default=None)
720 720 group_parent_id = Column("group_parent_id", Integer(), ForeignKey('groups.group_id'), nullable=True, unique=None, default=None)
721 721 group_description = Column("group_description", String(length=10000, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
722 722
723 723 parent_group = relationship('RepoGroup', remote_side=group_id)
724 724
725 725
726 726 def __init__(self, group_name='', parent_group=None):
727 727 self.group_name = group_name
728 728 self.parent_group = parent_group
729 729
730 730 def __repr__(self):
731 731 return "<%s('%s:%s')>" % (self.__class__.__name__, self.group_id,
732 732 self.group_name)
733 733
734 734 @classmethod
735 735 def groups_choices(cls):
736 736 from webhelpers.html import literal as _literal
737 737 repo_groups = [('', '')]
738 738 sep = ' &raquo; '
739 739 _name = lambda k: _literal(sep.join(k))
740 740
741 741 repo_groups.extend([(x.group_id, _name(x.full_path_splitted))
742 742 for x in cls.query().all()])
743 743
744 744 repo_groups = sorted(repo_groups, key=lambda t: t[1].split(sep)[0])
745 745 return repo_groups
746 746
747 747 @classmethod
748 748 def url_sep(cls):
749 749 return '/'
750 750
751 751 @classmethod
752 752 def get_by_group_name(cls, group_name, cache=False, case_insensitive=False):
753 753 if case_insensitive:
754 754 gr = cls.query()\
755 755 .filter(cls.group_name.ilike(group_name))
756 756 else:
757 757 gr = cls.query()\
758 758 .filter(cls.group_name == group_name)
759 759 if cache:
760 760 gr = gr.options(FromCache("sql_cache_short",
761 761 "get_group_%s" % group_name))
762 762 return gr.scalar()
763 763
764 764 @property
765 765 def parents(self):
766 766 parents_recursion_limit = 5
767 767 groups = []
768 768 if self.parent_group is None:
769 769 return groups
770 770 cur_gr = self.parent_group
771 771 groups.insert(0, cur_gr)
772 772 cnt = 0
773 773 while 1:
774 774 cnt += 1
775 775 gr = getattr(cur_gr, 'parent_group', None)
776 776 cur_gr = cur_gr.parent_group
777 777 if gr is None:
778 778 break
779 779 if cnt == parents_recursion_limit:
780 780 # this will prevent accidental infinit loops
781 781 log.error('group nested more than %s' %
782 782 parents_recursion_limit)
783 783 break
784 784
785 785 groups.insert(0, gr)
786 786 return groups
787 787
788 788 @property
789 789 def children(self):
790 790 return RepoGroup.query().filter(RepoGroup.parent_group == self)
791 791
792 792 @property
793 793 def name(self):
794 794 return self.group_name.split(RepoGroup.url_sep())[-1]
795 795
796 796 @property
797 797 def full_path(self):
798 798 return self.group_name
799 799
800 800 @property
801 801 def full_path_splitted(self):
802 802 return self.group_name.split(RepoGroup.url_sep())
803 803
804 804 @property
805 805 def repositories(self):
806 806 return Repository.query().filter(Repository.group == self)
807 807
808 808 @property
809 809 def repositories_recursive_count(self):
810 810 cnt = self.repositories.count()
811 811
812 812 def children_count(group):
813 813 cnt = 0
814 814 for child in group.children:
815 815 cnt += child.repositories.count()
816 816 cnt += children_count(child)
817 817 return cnt
818 818
819 819 return cnt + children_count(self)
820 820
821 821
822 822 def get_new_name(self, group_name):
823 823 """
824 824 returns new full group name based on parent and new name
825 825
826 826 :param group_name:
827 827 """
828 828 path_prefix = (self.parent_group.full_path_splitted if
829 829 self.parent_group else [])
830 830 return RepoGroup.url_sep().join(path_prefix + [group_name])
831 831
832 832
833 833 class Permission(Base, BaseModel):
834 834 __tablename__ = 'permissions'
835 835 __table_args__ = {'extend_existing':True}
836 836 permission_id = Column("permission_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
837 837 permission_name = Column("permission_name", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
838 838 permission_longname = Column("permission_longname", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
839 839
840 840 def __repr__(self):
841 841 return "<%s('%s:%s')>" % (self.__class__.__name__,
842 842 self.permission_id, self.permission_name)
843 843
844 844 @classmethod
845 845 def get_by_key(cls, key):
846 846 return cls.query().filter(cls.permission_name == key).scalar()
847 847
848 848 class UserRepoToPerm(Base, BaseModel):
849 849 __tablename__ = 'repo_to_perm'
850 850 __table_args__ = (UniqueConstraint('user_id', 'repository_id'), {'extend_existing':True})
851 851 repo_to_perm_id = Column("repo_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
852 852 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
853 853 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
854 854 repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
855 855
856 856 user = relationship('User')
857 857 permission = relationship('Permission')
858 858 repository = relationship('Repository')
859 859
860 860 class UserToPerm(Base, BaseModel):
861 861 __tablename__ = 'user_to_perm'
862 862 __table_args__ = (UniqueConstraint('user_id', 'permission_id'), {'extend_existing':True})
863 863 user_to_perm_id = Column("user_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
864 864 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
865 865 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
866 866
867 867 user = relationship('User')
868 868 permission = relationship('Permission')
869 869
870 870 @classmethod
871 871 def has_perm(cls, user_id, perm):
872 872 if not isinstance(perm, Permission):
873 873 raise Exception('perm needs to be an instance of Permission class')
874 874
875 875 return cls.query().filter(cls.user_id == user_id)\
876 876 .filter(cls.permission == perm).scalar() is not None
877 877
878 878 @classmethod
879 879 def grant_perm(cls, user_id, perm):
880 880 if not isinstance(perm, Permission):
881 881 raise Exception('perm needs to be an instance of Permission class')
882 882
883 883 new = cls()
884 884 new.user_id = user_id
885 885 new.permission = perm
886 886 try:
887 887 Session.add(new)
888 888 Session.commit()
889 889 except:
890 890 Session.rollback()
891 891
892 892
893 893 @classmethod
894 894 def revoke_perm(cls, user_id, perm):
895 895 if not isinstance(perm, Permission):
896 896 raise Exception('perm needs to be an instance of Permission class')
897 897
898 898 try:
899 899 cls.query().filter(cls.user_id == user_id)\
900 900 .filter(cls.permission == perm).delete()
901 901 Session.commit()
902 902 except:
903 903 Session.rollback()
904 904
905 905 class UsersGroupRepoToPerm(Base, BaseModel):
906 906 __tablename__ = 'users_group_repo_to_perm'
907 907 __table_args__ = (UniqueConstraint('repository_id', 'users_group_id', 'permission_id'), {'extend_existing':True})
908 908 users_group_to_perm_id = Column("users_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
909 909 users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
910 910 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
911 911 repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=None, default=None)
912 912
913 913 users_group = relationship('UsersGroup')
914 914 permission = relationship('Permission')
915 915 repository = relationship('Repository')
916 916
917 917 def __repr__(self):
918 918 return '<userGroup:%s => %s >' % (self.users_group, self.repository)
919 919
920 920 class UsersGroupToPerm(Base, BaseModel):
921 921 __tablename__ = 'users_group_to_perm'
922 922 users_group_to_perm_id = Column("users_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
923 923 users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
924 924 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
925 925
926 926 users_group = relationship('UsersGroup')
927 927 permission = relationship('Permission')
928 928
929 929
930 930 @classmethod
931 931 def has_perm(cls, users_group_id, perm):
932 932 if not isinstance(perm, Permission):
933 933 raise Exception('perm needs to be an instance of Permission class')
934 934
935 935 return cls.query().filter(cls.users_group_id ==
936 936 users_group_id)\
937 937 .filter(cls.permission == perm)\
938 938 .scalar() is not None
939 939
940 940 @classmethod
941 941 def grant_perm(cls, users_group_id, perm):
942 942 if not isinstance(perm, Permission):
943 943 raise Exception('perm needs to be an instance of Permission class')
944 944
945 945 new = cls()
946 946 new.users_group_id = users_group_id
947 947 new.permission = perm
948 948 try:
949 949 Session.add(new)
950 950 Session.commit()
951 951 except:
952 952 Session.rollback()
953 953
954 954
955 955 @classmethod
956 956 def revoke_perm(cls, users_group_id, perm):
957 957 if not isinstance(perm, Permission):
958 958 raise Exception('perm needs to be an instance of Permission class')
959 959
960 960 try:
961 961 cls.query().filter(cls.users_group_id == users_group_id)\
962 962 .filter(cls.permission == perm).delete()
963 963 Session.commit()
964 964 except:
965 965 Session.rollback()
966 966
967 967
968 968 class UserRepoGroupToPerm(Base, BaseModel):
969 969 __tablename__ = 'group_to_perm'
970 970 __table_args__ = (UniqueConstraint('group_id', 'permission_id'), {'extend_existing':True})
971 971
972 972 group_to_perm_id = Column("group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
973 973 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
974 974 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
975 975 group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=False, unique=None, default=None)
976 976
977 977 user = relationship('User')
978 978 permission = relationship('Permission')
979 979 group = relationship('RepoGroup')
980 980
981 981 class UsersGroupRepoGroupToPerm(Base, BaseModel):
982 982 __tablename__ = 'users_group_repo_group_to_perm'
983 983 __table_args__ = (UniqueConstraint('group_id', 'permission_id'), {'extend_existing':True})
984 984
985 985 users_group_repo_group_to_perm_id = Column("users_group_repo_group_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
986 986 users_group_id = Column("users_group_id", Integer(), ForeignKey('users_groups.users_group_id'), nullable=False, unique=None, default=None)
987 987 permission_id = Column("permission_id", Integer(), ForeignKey('permissions.permission_id'), nullable=False, unique=None, default=None)
988 988 group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'), nullable=False, unique=None, default=None)
989 989
990 990 users_group = relationship('UsersGroup')
991 991 permission = relationship('Permission')
992 992 group = relationship('RepoGroup')
993 993
994 994 class Statistics(Base, BaseModel):
995 995 __tablename__ = 'statistics'
996 996 __table_args__ = (UniqueConstraint('repository_id'), {'extend_existing':True})
997 997 stat_id = Column("stat_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
998 998 repository_id = Column("repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=False, unique=True, default=None)
999 999 stat_on_revision = Column("stat_on_revision", Integer(), nullable=False)
1000 1000 commit_activity = Column("commit_activity", LargeBinary(1000000), nullable=False)#JSON data
1001 1001 commit_activity_combined = Column("commit_activity_combined", LargeBinary(), nullable=False)#JSON data
1002 1002 languages = Column("languages", LargeBinary(1000000), nullable=False)#JSON data
1003 1003
1004 1004 repository = relationship('Repository', single_parent=True)
1005 1005
1006 1006 class UserFollowing(Base, BaseModel):
1007 1007 __tablename__ = 'user_followings'
1008 1008 __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
1009 1009 UniqueConstraint('user_id', 'follows_user_id')
1010 1010 , {'extend_existing':True})
1011 1011
1012 1012 user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
1013 1013 user_id = Column("user_id", Integer(), ForeignKey('users.user_id'), nullable=False, unique=None, default=None)
1014 1014 follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey('repositories.repo_id'), nullable=True, unique=None, default=None)
1015 1015 follows_user_id = Column("follows_user_id", Integer(), ForeignKey('users.user_id'), nullable=True, unique=None, default=None)
1016 1016 follows_from = Column('follows_from', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
1017 1017
1018 1018 user = relationship('User', primaryjoin='User.user_id==UserFollowing.user_id')
1019 1019
1020 1020 follows_user = relationship('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
1021 1021 follows_repository = relationship('Repository', order_by='Repository.repo_name')
1022 1022
1023 1023
1024 1024 @classmethod
1025 1025 def get_repo_followers(cls, repo_id):
1026 1026 return cls.query().filter(cls.follows_repo_id == repo_id)
1027 1027
1028 1028 class CacheInvalidation(Base, BaseModel):
1029 1029 __tablename__ = 'cache_invalidation'
1030 1030 __table_args__ = (UniqueConstraint('cache_key'), {'extend_existing':True})
1031 1031 cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
1032 1032 cache_key = Column("cache_key", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
1033 1033 cache_args = Column("cache_args", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
1034 1034 cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
1035 1035
1036 1036
1037 1037 def __init__(self, cache_key, cache_args=''):
1038 1038 self.cache_key = cache_key
1039 1039 self.cache_args = cache_args
1040 1040 self.cache_active = False
1041 1041
1042 1042 def __repr__(self):
1043 1043 return "<%s('%s:%s')>" % (self.__class__.__name__,
1044 1044 self.cache_id, self.cache_key)
1045 1045
1046 1046 @classmethod
1047 1047 def invalidate(cls, key):
1048 1048 """
1049 1049 Returns Invalidation object if this given key should be invalidated
1050 1050 None otherwise. `cache_active = False` means that this cache
1051 1051 state is not valid and needs to be invalidated
1052 1052
1053 1053 :param key:
1054 1054 """
1055 1055 return cls.query()\
1056 1056 .filter(CacheInvalidation.cache_key == key)\
1057 1057 .filter(CacheInvalidation.cache_active == False)\
1058 1058 .scalar()
1059 1059
1060 1060 @classmethod
1061 1061 def set_invalidate(cls, key):
1062 1062 """
1063 1063 Mark this Cache key for invalidation
1064 1064
1065 1065 :param key:
1066 1066 """
1067 1067
1068 1068 log.debug('marking %s for invalidation' % key)
1069 1069 inv_obj = Session().query(cls)\
1070 1070 .filter(cls.cache_key == key).scalar()
1071 1071 if inv_obj:
1072 1072 inv_obj.cache_active = False
1073 1073 else:
1074 1074 log.debug('cache key not found in invalidation db -> creating one')
1075 1075 inv_obj = CacheInvalidation(key)
1076 1076
1077 1077 try:
1078 1078 Session.add(inv_obj)
1079 1079 Session.commit()
1080 1080 except Exception:
1081 1081 log.error(traceback.format_exc())
1082 1082 Session.rollback()
1083 1083
1084 1084 @classmethod
1085 1085 def set_valid(cls, key):
1086 1086 """
1087 1087 Mark this cache key as active and currently cached
1088 1088
1089 1089 :param key:
1090 1090 """
1091 1091 inv_obj = Session().query(CacheInvalidation)\
1092 1092 .filter(CacheInvalidation.cache_key == key).scalar()
1093 1093 inv_obj.cache_active = True
1094 1094 Session.add(inv_obj)
1095 1095 Session.commit()
1096 1096
1097 1097
1098 1098 class ChangesetComment(Base, BaseModel):
1099 1099 __tablename__ = 'changeset_comments'
1100 1100 __table_args__ = ({'extend_existing':True},)
1101 1101 comment_id = Column('comment_id', Integer(), nullable=False, primary_key=True)
1102 1102 repo_id = Column('repo_id', Integer(), ForeignKey('repositories.repo_id'), nullable=False)
1103 1103 revision = Column('revision', String(40), nullable=False)
1104 1104 line_no = Column('line_no', Unicode(10), nullable=True)
1105 1105 f_path = Column('f_path', Unicode(1000), nullable=True)
1106 1106 user_id = Column('user_id', Integer(), ForeignKey('users.user_id'), nullable=False)
1107 1107 text = Column('text', Unicode(25000), nullable=False)
1108 1108 modified_at = Column('modified_at', DateTime(), nullable=False, default=datetime.datetime.now)
1109 1109
1110 1110 author = relationship('User')
1111 1111 repo = relationship('Repository')
1112 1112
1113 1113
1114 1114 class DbMigrateVersion(Base, BaseModel):
1115 1115 __tablename__ = 'db_migrate_version'
1116 1116 __table_args__ = {'extend_existing':True}
1117 1117 repository_id = Column('repository_id', String(250), primary_key=True)
1118 1118 repository_path = Column('repository_path', Text)
1119 1119 version = Column('version', Integer)
1120 1120
@@ -1,480 +1,474 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.model.user
4 4 ~~~~~~~~~~~~~~~~~~~~
5 5
6 6 users model for RhodeCode
7 7
8 8 :created_on: Apr 9, 2010
9 9 :author: marcink
10 10 :copyright: (C) 2009-2011 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software: you can redistribute it and/or modify
14 14 # it under the terms of the GNU General Public License as published by
15 15 # the Free Software Foundation, either version 3 of the License, or
16 16 # (at your option) any later version.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program. If not, see <http://www.gnu.org/licenses/>.
25 25
26 26 import logging
27 27 import traceback
28 28
29 29 from pylons.i18n.translation import _
30 30
31 31 from rhodecode.lib import safe_unicode
32 32 from rhodecode.lib.caching_query import FromCache
33 33
34 34 from rhodecode.model import BaseModel
35 35 from rhodecode.model.db import User, UserRepoToPerm, Repository, Permission, \
36 36 UserToPerm, UsersGroupRepoToPerm, UsersGroupToPerm, UsersGroupMember
37 37 from rhodecode.lib.exceptions import DefaultUserException, \
38 38 UserOwnsReposException
39 39
40 40 from sqlalchemy.exc import DatabaseError
41 41 from rhodecode.lib import generate_api_key
42 42 from sqlalchemy.orm import joinedload
43 43
44 44 log = logging.getLogger(__name__)
45 45
46 46 PERM_WEIGHTS = {'repository.none': 0,
47 47 'repository.read': 1,
48 48 'repository.write': 3,
49 49 'repository.admin': 3}
50 50
51 51
52 52 class UserModel(BaseModel):
53 53 def get(self, user_id, cache=False):
54 54 user = self.sa.query(User)
55 55 if cache:
56 56 user = user.options(FromCache("sql_cache_short",
57 57 "get_user_%s" % user_id))
58 58 return user.get(user_id)
59 59
60 60 def get_by_username(self, username, cache=False, case_insensitive=False):
61 61
62 62 if case_insensitive:
63 63 user = self.sa.query(User).filter(User.username.ilike(username))
64 64 else:
65 65 user = self.sa.query(User)\
66 66 .filter(User.username == username)
67 67 if cache:
68 68 user = user.options(FromCache("sql_cache_short",
69 69 "get_user_%s" % username))
70 70 return user.scalar()
71 71
72 72 def get_by_api_key(self, api_key, cache=False):
73
74 user = self.sa.query(User)\
75 .filter(User.api_key == api_key)
76 if cache:
77 user = user.options(FromCache("sql_cache_short",
78 "get_user_%s" % api_key))
79 return user.scalar()
73 return User.get_by_api_key(api_key, cache)
80 74
81 75 def create(self, form_data):
82 76 try:
83 77 new_user = User()
84 78 for k, v in form_data.items():
85 79 setattr(new_user, k, v)
86 80
87 81 new_user.api_key = generate_api_key(form_data['username'])
88 82 self.sa.add(new_user)
89 83 self.sa.commit()
90 84 return new_user
91 85 except:
92 86 log.error(traceback.format_exc())
93 87 self.sa.rollback()
94 88 raise
95 89
96 90
97 91 def create_or_update(self, username, password, email, name, lastname,
98 92 active=True, admin=False, ldap_dn=None):
99 93 """
100 94 Creates a new instance if not found, or updates current one
101 95
102 96 :param username:
103 97 :param password:
104 98 :param email:
105 99 :param active:
106 100 :param name:
107 101 :param lastname:
108 102 :param active:
109 103 :param admin:
110 104 :param ldap_dn:
111 105 """
112 106
113 107 from rhodecode.lib.auth import get_crypt_password
114 108
115 109 log.debug('Checking for %s account in RhodeCode database', username)
116 110 user = User.get_by_username(username, case_insensitive=True)
117 111 if user is None:
118 112 log.debug('creating new user %s', username)
119 113 new_user = User()
120 114 else:
121 115 log.debug('updating user %s', username)
122 116 new_user = user
123 117
124 118 try:
125 119 new_user.username = username
126 120 new_user.admin = admin
127 121 new_user.password = get_crypt_password(password)
128 122 new_user.api_key = generate_api_key(username)
129 123 new_user.email = email
130 124 new_user.active = active
131 125 new_user.ldap_dn = safe_unicode(ldap_dn) if ldap_dn else None
132 126 new_user.name = name
133 127 new_user.lastname = lastname
134 128
135 129 self.sa.add(new_user)
136 130 self.sa.commit()
137 131 return new_user
138 132 except (DatabaseError,):
139 133 log.error(traceback.format_exc())
140 134 self.sa.rollback()
141 135 raise
142 136
143 137
144 138 def create_for_container_auth(self, username, attrs):
145 139 """
146 140 Creates the given user if it's not already in the database
147 141
148 142 :param username:
149 143 :param attrs:
150 144 """
151 145 if self.get_by_username(username, case_insensitive=True) is None:
152 146
153 147 # autogenerate email for container account without one
154 148 generate_email = lambda usr: '%s@container_auth.account' % usr
155 149
156 150 try:
157 151 new_user = User()
158 152 new_user.username = username
159 153 new_user.password = None
160 154 new_user.api_key = generate_api_key(username)
161 155 new_user.email = attrs['email']
162 156 new_user.active = attrs.get('active', True)
163 157 new_user.name = attrs['name'] or generate_email(username)
164 158 new_user.lastname = attrs['lastname']
165 159
166 160 self.sa.add(new_user)
167 161 self.sa.commit()
168 162 return new_user
169 163 except (DatabaseError,):
170 164 log.error(traceback.format_exc())
171 165 self.sa.rollback()
172 166 raise
173 167 log.debug('User %s already exists. Skipping creation of account'
174 168 ' for container auth.', username)
175 169 return None
176 170
177 171 def create_ldap(self, username, password, user_dn, attrs):
178 172 """
179 173 Checks if user is in database, if not creates this user marked
180 174 as ldap user
181 175
182 176 :param username:
183 177 :param password:
184 178 :param user_dn:
185 179 :param attrs:
186 180 """
187 181 from rhodecode.lib.auth import get_crypt_password
188 182 log.debug('Checking for such ldap account in RhodeCode database')
189 183 if self.get_by_username(username, case_insensitive=True) is None:
190 184
191 185 # autogenerate email for ldap account without one
192 186 generate_email = lambda usr: '%s@ldap.account' % usr
193 187
194 188 try:
195 189 new_user = User()
196 190 username = username.lower()
197 191 # add ldap account always lowercase
198 192 new_user.username = username
199 193 new_user.password = get_crypt_password(password)
200 194 new_user.api_key = generate_api_key(username)
201 195 new_user.email = attrs['email'] or generate_email(username)
202 196 new_user.active = attrs.get('active', True)
203 197 new_user.ldap_dn = safe_unicode(user_dn)
204 198 new_user.name = attrs['name']
205 199 new_user.lastname = attrs['lastname']
206 200
207 201 self.sa.add(new_user)
208 202 self.sa.commit()
209 203 return new_user
210 204 except (DatabaseError,):
211 205 log.error(traceback.format_exc())
212 206 self.sa.rollback()
213 207 raise
214 208 log.debug('this %s user exists skipping creation of ldap account',
215 209 username)
216 210 return None
217 211
218 212 def create_registration(self, form_data):
219 213 from rhodecode.lib.celerylib import tasks, run_task
220 214 try:
221 215 new_user = User()
222 216 for k, v in form_data.items():
223 217 if k != 'admin':
224 218 setattr(new_user, k, v)
225 219
226 220 self.sa.add(new_user)
227 221 self.sa.commit()
228 222 body = ('New user registration\n'
229 223 'username: %s\n'
230 224 'email: %s\n')
231 225 body = body % (form_data['username'], form_data['email'])
232 226
233 227 run_task(tasks.send_email, None,
234 228 _('[RhodeCode] New User registration'),
235 229 body)
236 230 except:
237 231 log.error(traceback.format_exc())
238 232 self.sa.rollback()
239 233 raise
240 234
241 235 def update(self, user_id, form_data):
242 236 try:
243 237 user = self.get(user_id, cache=False)
244 238 if user.username == 'default':
245 239 raise DefaultUserException(
246 240 _("You can't Edit this user since it's"
247 241 " crucial for entire application"))
248 242
249 243 for k, v in form_data.items():
250 244 if k == 'new_password' and v != '':
251 245 user.password = v
252 246 user.api_key = generate_api_key(user.username)
253 247 else:
254 248 setattr(user, k, v)
255 249
256 250 self.sa.add(user)
257 251 self.sa.commit()
258 252 except:
259 253 log.error(traceback.format_exc())
260 254 self.sa.rollback()
261 255 raise
262 256
263 257 def update_my_account(self, user_id, form_data):
264 258 try:
265 259 user = self.get(user_id, cache=False)
266 260 if user.username == 'default':
267 261 raise DefaultUserException(
268 262 _("You can't Edit this user since it's"
269 263 " crucial for entire application"))
270 264 for k, v in form_data.items():
271 265 if k == 'new_password' and v != '':
272 266 user.password = v
273 267 user.api_key = generate_api_key(user.username)
274 268 else:
275 269 if k not in ['admin', 'active']:
276 270 setattr(user, k, v)
277 271
278 272 self.sa.add(user)
279 273 self.sa.commit()
280 274 except:
281 275 log.error(traceback.format_exc())
282 276 self.sa.rollback()
283 277 raise
284 278
285 279 def delete(self, user_id):
286 280 try:
287 281 user = self.get(user_id, cache=False)
288 282 if user.username == 'default':
289 283 raise DefaultUserException(
290 284 _("You can't remove this user since it's"
291 285 " crucial for entire application"))
292 286 if user.repositories:
293 287 raise UserOwnsReposException(_('This user still owns %s '
294 288 'repositories and cannot be '
295 289 'removed. Switch owners or '
296 290 'remove those repositories') \
297 291 % user.repositories)
298 292 self.sa.delete(user)
299 293 self.sa.commit()
300 294 except:
301 295 log.error(traceback.format_exc())
302 296 self.sa.rollback()
303 297 raise
304 298
305 299 def reset_password_link(self, data):
306 300 from rhodecode.lib.celerylib import tasks, run_task
307 301 run_task(tasks.send_password_link, data['email'])
308 302
309 303 def reset_password(self, data):
310 304 from rhodecode.lib.celerylib import tasks, run_task
311 305 run_task(tasks.reset_user_password, data['email'])
312 306
313 307 def fill_data(self, auth_user, user_id=None, api_key=None):
314 308 """
315 309 Fetches auth_user by user_id,or api_key if present.
316 310 Fills auth_user attributes with those taken from database.
317 311 Additionally set's is_authenitated if lookup fails
318 312 present in database
319 313
320 314 :param auth_user: instance of user to set attributes
321 315 :param user_id: user id to fetch by
322 316 :param api_key: api key to fetch by
323 317 """
324 318 if user_id is None and api_key is None:
325 319 raise Exception('You need to pass user_id or api_key')
326 320
327 321 try:
328 322 if api_key:
329 323 dbuser = self.get_by_api_key(api_key)
330 324 else:
331 325 dbuser = self.get(user_id)
332 326
333 327 if dbuser is not None and dbuser.active:
334 328 log.debug('filling %s data', dbuser)
335 329 for k, v in dbuser.get_dict().items():
336 330 setattr(auth_user, k, v)
337 331 else:
338 332 return False
339 333
340 334 except:
341 335 log.error(traceback.format_exc())
342 336 auth_user.is_authenticated = False
343 337 return False
344 338
345 339 return True
346 340
347 341 def fill_perms(self, user):
348 342 """
349 343 Fills user permission attribute with permissions taken from database
350 344 works for permissions given for repositories, and for permissions that
351 345 are granted to groups
352 346
353 347 :param user: user instance to fill his perms
354 348 """
355 349
356 350 user.permissions['repositories'] = {}
357 351 user.permissions['global'] = set()
358 352
359 353 #======================================================================
360 354 # fetch default permissions
361 355 #======================================================================
362 356 default_user = self.get_by_username('default', cache=True)
363 357
364 358 default_perms = self.sa.query(UserRepoToPerm, Repository, Permission)\
365 359 .join((Repository, UserRepoToPerm.repository_id ==
366 360 Repository.repo_id))\
367 361 .join((Permission, UserRepoToPerm.permission_id ==
368 362 Permission.permission_id))\
369 363 .filter(UserRepoToPerm.user == default_user).all()
370 364
371 365 if user.is_admin:
372 366 #==================================================================
373 367 # #admin have all default rights set to admin
374 368 #==================================================================
375 369 user.permissions['global'].add('hg.admin')
376 370
377 371 for perm in default_perms:
378 372 p = 'repository.admin'
379 373 user.permissions['repositories'][perm.UserRepoToPerm.
380 374 repository.repo_name] = p
381 375
382 376 else:
383 377 #==================================================================
384 378 # set default permissions
385 379 #==================================================================
386 380 uid = user.user_id
387 381
388 382 #default global
389 383 default_global_perms = self.sa.query(UserToPerm)\
390 384 .filter(UserToPerm.user == default_user)
391 385
392 386 for perm in default_global_perms:
393 387 user.permissions['global'].add(perm.permission.permission_name)
394 388
395 389 #default for repositories
396 390 for perm in default_perms:
397 391 if perm.Repository.private and not (perm.Repository.user_id ==
398 392 uid):
399 393 #diself.sable defaults for private repos,
400 394 p = 'repository.none'
401 395 elif perm.Repository.user_id == uid:
402 396 #set admin if owner
403 397 p = 'repository.admin'
404 398 else:
405 399 p = perm.Permission.permission_name
406 400
407 401 user.permissions['repositories'][perm.UserRepoToPerm.
408 402 repository.repo_name] = p
409 403
410 404 #==================================================================
411 405 # overwrite default with user permissions if any
412 406 #==================================================================
413 407
414 408 #user global
415 409 user_perms = self.sa.query(UserToPerm)\
416 410 .options(joinedload(UserToPerm.permission))\
417 411 .filter(UserToPerm.user_id == uid).all()
418 412
419 413 for perm in user_perms:
420 414 user.permissions['global'].add(perm.permission.
421 415 permission_name)
422 416
423 417 #user repositories
424 418 user_repo_perms = self.sa.query(UserRepoToPerm, Permission,
425 419 Repository)\
426 420 .join((Repository, UserRepoToPerm.repository_id ==
427 421 Repository.repo_id))\
428 422 .join((Permission, UserRepoToPerm.permission_id ==
429 423 Permission.permission_id))\
430 424 .filter(UserRepoToPerm.user_id == uid).all()
431 425
432 426 for perm in user_repo_perms:
433 427 # set admin if owner
434 428 if perm.Repository.user_id == uid:
435 429 p = 'repository.admin'
436 430 else:
437 431 p = perm.Permission.permission_name
438 432 user.permissions['repositories'][perm.UserRepoToPerm.
439 433 repository.repo_name] = p
440 434
441 435 #==================================================================
442 436 # check if user is part of groups for this repository and fill in
443 437 # (or replace with higher) permissions
444 438 #==================================================================
445 439
446 440 #users group global
447 441 user_perms_from_users_groups = self.sa.query(UsersGroupToPerm)\
448 442 .options(joinedload(UsersGroupToPerm.permission))\
449 443 .join((UsersGroupMember, UsersGroupToPerm.users_group_id ==
450 444 UsersGroupMember.users_group_id))\
451 445 .filter(UsersGroupMember.user_id == uid).all()
452 446
453 447 for perm in user_perms_from_users_groups:
454 448 user.permissions['global'].add(perm.permission.permission_name)
455 449
456 450 #users group repositories
457 451 user_repo_perms_from_users_groups = self.sa.query(
458 452 UsersGroupRepoToPerm,
459 453 Permission, Repository,)\
460 454 .join((Repository, UsersGroupRepoToPerm.repository_id ==
461 455 Repository.repo_id))\
462 456 .join((Permission, UsersGroupRepoToPerm.permission_id ==
463 457 Permission.permission_id))\
464 458 .join((UsersGroupMember, UsersGroupRepoToPerm.users_group_id ==
465 459 UsersGroupMember.users_group_id))\
466 460 .filter(UsersGroupMember.user_id == uid).all()
467 461
468 462 for perm in user_repo_perms_from_users_groups:
469 463 p = perm.Permission.permission_name
470 464 cur_perm = user.permissions['repositories'][perm.
471 465 UsersGroupRepoToPerm.
472 466 repository.repo_name]
473 467 #overwrite permission only if it's greater than permission
474 468 # given from other sources
475 469 if PERM_WEIGHTS[p] > PERM_WEIGHTS[cur_perm]:
476 470 user.permissions['repositories'][perm.UsersGroupRepoToPerm.
477 471 repository.repo_name] = p
478 472
479 473 return user
480 474
@@ -1,261 +1,260 b''
1 1 # -*- coding: utf-8 -*-
2 2 from rhodecode.tests import *
3 3 from rhodecode.model.db import User
4 4 from rhodecode.lib import generate_api_key
5 5 from rhodecode.lib.auth import check_password
6 6
7 7
8 8 class TestLoginController(TestController):
9 9
10 10 def test_index(self):
11 11 response = self.app.get(url(controller='login', action='index'))
12 12 self.assertEqual(response.status, '200 OK')
13 13 # Test response...
14 14
15 15 def test_login_admin_ok(self):
16 16 response = self.app.post(url(controller='login', action='index'),
17 17 {'username':'test_admin',
18 18 'password':'test12'})
19 19 self.assertEqual(response.status, '302 Found')
20 20 self.assertEqual(response.session['rhodecode_user'].username ,
21 21 'test_admin')
22 22 response = response.follow()
23 23 self.assertTrue('%s repository' % HG_REPO in response.body)
24 24
25 25 def test_login_regular_ok(self):
26 26 response = self.app.post(url(controller='login', action='index'),
27 27 {'username':'test_regular',
28 28 'password':'test12'})
29 29
30 30 self.assertEqual(response.status, '302 Found')
31 31 self.assertEqual(response.session['rhodecode_user'].username ,
32 32 'test_regular')
33 33 response = response.follow()
34 34 self.assertTrue('%s repository' % HG_REPO in response.body)
35 35 self.assertTrue('<a title="Admin" href="/_admin">' not in response.body)
36 36
37 37 def test_login_ok_came_from(self):
38 38 test_came_from = '/_admin/users'
39 39 response = self.app.post(url(controller='login', action='index',
40 40 came_from=test_came_from),
41 41 {'username':'test_admin',
42 42 'password':'test12'})
43 43 self.assertEqual(response.status, '302 Found')
44 44 response = response.follow()
45 45
46 46 self.assertEqual(response.status, '200 OK')
47 47 self.assertTrue('Users administration' in response.body)
48 48
49 49
50 50 def test_login_short_password(self):
51 51 response = self.app.post(url(controller='login', action='index'),
52 52 {'username':'test_admin',
53 53 'password':'as'})
54 54 self.assertEqual(response.status, '200 OK')
55 55
56 56 self.assertTrue('Enter 3 characters or more' in response.body)
57 57
58 58 def test_login_wrong_username_password(self):
59 59 response = self.app.post(url(controller='login', action='index'),
60 60 {'username':'error',
61 61 'password':'test12'})
62 62 self.assertEqual(response.status , '200 OK')
63 63
64 64 self.assertTrue('invalid user name' in response.body)
65 65 self.assertTrue('invalid password' in response.body)
66 66
67 67 #==========================================================================
68 68 # REGISTRATIONS
69 69 #==========================================================================
70 70 def test_register(self):
71 71 response = self.app.get(url(controller='login', action='register'))
72 72 self.assertTrue('Sign Up to RhodeCode' in response.body)
73 73
74 74 def test_register_err_same_username(self):
75 75 response = self.app.post(url(controller='login', action='register'),
76 76 {'username':'test_admin',
77 77 'password':'test12',
78 78 'password_confirmation':'test12',
79 79 'email':'goodmail@domain.com',
80 80 'name':'test',
81 81 'lastname':'test'})
82 82
83 83 self.assertEqual(response.status , '200 OK')
84 84 self.assertTrue('This username already exists' in response.body)
85 85
86 86 def test_register_err_same_email(self):
87 87 response = self.app.post(url(controller='login', action='register'),
88 88 {'username':'test_admin_0',
89 89 'password':'test12',
90 90 'password_confirmation':'test12',
91 91 'email':'test_admin@mail.com',
92 92 'name':'test',
93 93 'lastname':'test'})
94 94
95 95 self.assertEqual(response.status , '200 OK')
96 96 assert 'This e-mail address is already taken' in response.body
97 97
98 98 def test_register_err_same_email_case_sensitive(self):
99 99 response = self.app.post(url(controller='login', action='register'),
100 100 {'username':'test_admin_1',
101 101 'password':'test12',
102 102 'password_confirmation':'test12',
103 103 'email':'TesT_Admin@mail.COM',
104 104 'name':'test',
105 105 'lastname':'test'})
106 106 self.assertEqual(response.status , '200 OK')
107 107 assert 'This e-mail address is already taken' in response.body
108 108
109 109 def test_register_err_wrong_data(self):
110 110 response = self.app.post(url(controller='login', action='register'),
111 111 {'username':'xs',
112 112 'password':'test',
113 113 'password_confirmation':'test',
114 114 'email':'goodmailm',
115 115 'name':'test',
116 116 'lastname':'test'})
117 117 self.assertEqual(response.status , '200 OK')
118 118 assert 'An email address must contain a single @' in response.body
119 119 assert 'Enter a value 6 characters long or more' in response.body
120 120
121 121
122 122 def test_register_err_username(self):
123 123 response = self.app.post(url(controller='login', action='register'),
124 124 {'username':'error user',
125 125 'password':'test12',
126 126 'password_confirmation':'test12',
127 127 'email':'goodmailm',
128 128 'name':'test',
129 129 'lastname':'test'})
130 130
131 131 self.assertEqual(response.status , '200 OK')
132 132 assert 'An email address must contain a single @' in response.body
133 133 assert ('Username may only contain '
134 134 'alphanumeric characters underscores, '
135 135 'periods or dashes and must begin with '
136 136 'alphanumeric character') in response.body
137 137
138 138 def test_register_err_case_sensitive(self):
139 139 response = self.app.post(url(controller='login', action='register'),
140 140 {'username':'Test_Admin',
141 141 'password':'test12',
142 142 'password_confirmation':'test12',
143 143 'email':'goodmailm',
144 144 'name':'test',
145 145 'lastname':'test'})
146 146
147 147 self.assertEqual(response.status , '200 OK')
148 148 self.assertTrue('An email address must contain a single @' in response.body)
149 149 self.assertTrue('This username already exists' in response.body)
150 150
151 151
152 152
153 153 def test_register_special_chars(self):
154 154 response = self.app.post(url(controller='login', action='register'),
155 155 {'username':'xxxaxn',
156 156 'password':'Δ…Δ‡ΕΊΕΌΔ…Ε›Ε›Ε›Ε›',
157 157 'password_confirmation':'Δ…Δ‡ΕΊΕΌΔ…Ε›Ε›Ε›Ε›',
158 158 'email':'goodmailm@test.plx',
159 159 'name':'test',
160 160 'lastname':'test'})
161 161
162 162 self.assertEqual(response.status , '200 OK')
163 163 self.assertTrue('Invalid characters in password' in response.body)
164 164
165 165
166 166 def test_register_password_mismatch(self):
167 167 response = self.app.post(url(controller='login', action='register'),
168 168 {'username':'xs',
169 169 'password':'123qwe',
170 170 'password_confirmation':'qwe123',
171 171 'email':'goodmailm@test.plxa',
172 172 'name':'test',
173 173 'lastname':'test'})
174 174
175 175 self.assertEqual(response.status , '200 OK')
176 176 assert 'Passwords do not match' in response.body
177 177
178 178 def test_register_ok(self):
179 179 username = 'test_regular4'
180 180 password = 'qweqwe'
181 181 email = 'marcin@test.com'
182 182 name = 'testname'
183 183 lastname = 'testlastname'
184 184
185 185 response = self.app.post(url(controller='login', action='register'),
186 186 {'username':username,
187 187 'password':password,
188 188 'password_confirmation':password,
189 189 'email':email,
190 190 'name':name,
191 191 'lastname':lastname})
192 192 self.assertEqual(response.status , '302 Found')
193 193 assert 'You have successfully registered into rhodecode' in response.session['flash'][0], 'No flash message about user registration'
194 194
195 195 ret = self.sa.query(User).filter(User.username == 'test_regular4').one()
196 196 assert ret.username == username , 'field mismatch %s %s' % (ret.username, username)
197 197 assert check_password(password, ret.password) == True , 'password mismatch'
198 198 assert ret.email == email , 'field mismatch %s %s' % (ret.email, email)
199 199 assert ret.name == name , 'field mismatch %s %s' % (ret.name, name)
200 200 assert ret.lastname == lastname , 'field mismatch %s %s' % (ret.lastname, lastname)
201 201
202 202
203 203 def test_forgot_password_wrong_mail(self):
204 204 response = self.app.post(url(controller='login', action='password_reset'),
205 205 {'email':'marcin@wrongmail.org', })
206 206
207 207 assert "This e-mail address doesn't exist" in response.body, 'Missing error message about wrong email'
208 208
209 209 def test_forgot_password(self):
210 210 response = self.app.get(url(controller='login',
211 211 action='password_reset'))
212 212 self.assertEqual(response.status , '200 OK')
213 213
214 214 username = 'test_password_reset_1'
215 215 password = 'qweqwe'
216 216 email = 'marcin@python-works.com'
217 217 name = 'passwd'
218 218 lastname = 'reset'
219 219
220 220 new = User()
221 221 new.username = username
222 222 new.password = password
223 223 new.email = email
224 224 new.name = name
225 225 new.lastname = lastname
226 226 new.api_key = generate_api_key(username)
227 227 self.sa.add(new)
228 228 self.sa.commit()
229 229
230 230 response = self.app.post(url(controller='login',
231 231 action='password_reset'),
232 232 {'email':email, })
233 233
234 234 self.checkSessionFlash(response, 'Your password reset link was sent')
235 235
236 236 response = response.follow()
237 237
238 238 # BAD KEY
239 239
240 240 key = "bad"
241 241 response = self.app.get(url(controller='login',
242 242 action='password_reset_confirmation',
243 243 key=key))
244 244 self.assertEqual(response.status, '302 Found')
245 245 self.assertTrue(response.location.endswith(url('reset_password')))
246 246
247 247 # GOOD KEY
248 248
249 249 key = User.get_by_username(username).api_key
250
251 250 response = self.app.get(url(controller='login',
252 251 action='password_reset_confirmation',
253 252 key=key))
254 253 self.assertEqual(response.status, '302 Found')
255 254 self.assertTrue(response.location.endswith(url('login_home')))
256 255
257 256 self.checkSessionFlash(response,
258 257 ('Your password reset was successful, '
259 258 'new password has been sent to your email'))
260 259
261 260 response = response.follow()
General Comments 0
You need to be logged in to leave comments. Login now