# Patch against rhodecode/model/db.py.
# NOTE(review): newlines appear to have been collapsed in this copy (the whole diff sits on two
# lines), so it likely needs its line structure restored before `git apply` will accept it.
#
# What the patch changes:
# - Adds `delete` to the `from sqlalchemy import (...)` list and introduces
#   BaseModel.delete(custom_cls=None). This classmethod only builds and returns an unexecuted
#   sqlalchemy `delete(custom_cls)` / `delete(cls)` statement; run it with cls.execute(stmt).
#   Inside the method body, `delete(...)` resolves to the module-level sqlalchemy import, not the
#   classmethod itself, because names defined in a class body are not in scope inside its methods.
# - User.get: bare `return` becomes explicit `return None` (same result).
# - User.get_by_username, User.get_by_email, UserGroup.get_by_group_name and
#   RepoGroup.get_by_group_name now return None early when the lookup key is falsy
#   (None, empty string) instead of issuing a query.
# - User.get_by_email: the UserEmailMap fallback query now computes its own email_key inside its
#   `if cache:` branch rather than relying on the variable set in the earlier `if cache:` branch
#   (same inputs, so the cache key is unchanged).
# - All remaining edits only reflow previously wrapped lines onto one line.
#
# NOTE(review):
# - If BaseModel already defines another attribute named `delete` elsewhere in the class (not
#   visible in this patch), the two definitions shadow each other and the later one in the class
#   body wins. Confirm no caller depends on a different `Model.delete(...)` signature.
# - The BaseModel hunk adds two blank `+` lines after delete(), leaving two blank lines between
#   methods; PEP 8 uses a single blank line between methods inside a class.
# - The early returns are a behavior change for falsy keys. Previously, e.g. get_by_username("")
#   ran a query; now it returns None without touching the database. Presumably intended; confirm.
# - The delete() docstring example uses `cls.user_id`, which only applies when cls is User.
diff --git a/rhodecode/model/db.py b/rhodecode/model/db.py --- a/rhodecode/model/db.py +++ b/rhodecode/model/db.py @@ -35,7 +35,7 @@ import collections import pyotp from sqlalchemy import ( - or_, and_, not_, func, cast, TypeDecorator, event, select, + or_, and_, not_, func, cast, TypeDecorator, event, select, delete, true, false, null, union_all, Index, Sequence, UniqueConstraint, ForeignKey, CheckConstraint, Column, Boolean, String, Unicode, UnicodeText, DateTime, Integer, LargeBinary, @@ -274,6 +274,22 @@ class BaseModel(object): return stmt @classmethod + def delete(cls, custom_cls=None): + """ + stmt = cls.delete().where(cls.user_id==1) + # optionally + stmt = cls.delete(User).where(cls.user_id==1) + result = cls.execute(stmt) + """ + + if custom_cls: + stmt = delete(custom_cls) + else: + stmt = delete(cls) + return stmt + + + @classmethod def execute(cls, stmt): return Session().execute(stmt) @@ -1075,28 +1091,26 @@ class User(Base, BaseModel): @classmethod def get(cls, user_id, cache=False): if not user_id: - return + return None q = cls.select().where(cls.user_id == user_id) if cache: - q = q.options( - FromCache("sql_cache_short", f"get_users_{user_id}")) + q = q.options(FromCache("sql_cache_short", f"get_users_{user_id}")) return cls.execute(q).scalar_one_or_none() @classmethod - def get_by_username(cls, username, case_insensitive=False, - cache=False): + def get_by_username(cls, username, case_insensitive=False, cache=False): + if not username: + return None if case_insensitive: - q = cls.select().where( - func.lower(cls.username) == func.lower(username)) + q = cls.select().where(func.lower(cls.username) == func.lower(username)) else: q = cls.select().where(cls.username == username) if cache: hash_key = _hash_key(username) - q = q.options( - FromCache("sql_cache_short", f"get_user_by_name_{hash_key}")) + q = q.options(FromCache("sql_cache_short", f"get_user_by_name_{hash_key}")) return cls.execute(q).scalar_one_or_none() @@ -1125,6 +1139,8 @@ class 
User(Base, BaseModel): @classmethod def get_by_email(cls, email, case_insensitive=False, cache=False): + if not email: + return None if case_insensitive: q = cls.select().where(func.lower(cls.email) == func.lower(email)) @@ -1133,8 +1149,7 @@ class User(Base, BaseModel): if cache: email_key = _hash_key(email) - q = q.options( - FromCache("sql_cache_short", f"get_email_key_{email_key}")) + q = q.options(FromCache("sql_cache_short", f"get_email_key_{email_key}")) ret = cls.execute(q).scalar_one_or_none() @@ -1147,8 +1162,8 @@ class User(Base, BaseModel): q = q.where(UserEmailMap.email == email) q = q.options(joinedload(UserEmailMap.user)) if cache: - q = q.options( - FromCache("sql_cache_short", f"get_email_map_key_{email_key}")) + email_key = _hash_key(email) + q = q.options(FromCache("sql_cache_short", f"get_email_map_key_{email_key}")) result = cls.execute(q).scalar_one_or_none() ret = getattr(result, 'user', None) @@ -1642,11 +1657,12 @@ class UserGroup(Base, BaseModel): return f"<{self.cls_name}('id:{self.users_group_id}:{self.users_group_name}')>" @classmethod - def get_by_group_name(cls, group_name, cache=False, - case_insensitive=False): + def get_by_group_name(cls, group_name, cache=False, case_insensitive=False): + if not group_name: + return None + if case_insensitive: - q = cls.query().filter(func.lower(cls.users_group_name) == - func.lower(group_name)) + q = cls.query().filter(func.lower(cls.users_group_name) == func.lower(group_name)) else: q = cls.query().filter(cls.users_group_name == group_name) @@ -2968,9 +2984,11 @@ class RepoGroup(Base, BaseModel): @classmethod def get_by_group_name(cls, group_name, cache=False, case_insensitive=False): + if not group_name: + return None + if case_insensitive: - gr = cls.query().filter(func.lower(cls.group_name) - == func.lower(group_name)) + gr = cls.query().filter(func.lower(cls.group_name) == func.lower(group_name)) else: gr = cls.query().filter(cls.group_name == group_name) if cache: