# HG changeset patch # User Marcin Kuzminski # Date 2010-12-11 00:54:12 # Node ID 9753e0907827670aebf95083e20e5261f2b3009c # Parent 634596f81cfdd5cf9dd3755bad2ac451e2bd2b19 added dbmigrate package, added model changes moved out upgrade db command to that package diff --git a/rhodecode/lib/dbmigrate/__init__.py b/rhodecode/lib/dbmigrate/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/__init__.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +""" + rhodecode.lib.dbmigrate.__init__ + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Database migration modules + + :created_on: Dec 11, 2010 + :author: marcink + :copyright: (C) 2009-2010 Marcin Kuzminski + :license: GPLv3, see COPYING for more details. +""" +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; version 2 +# of the License or (at your option) any later version of the license. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, +# MA 02110-1301, USA. 
+ +from rhodecode.lib.utils import BasePasterCommand +from rhodecode.lib.utils import BasePasterCommand, Command, add_cache + +from sqlalchemy import engine_from_config + +class UpgradeDb(BasePasterCommand): + """Command used for paster to upgrade our database to newer version + """ + + max_args = 1 + min_args = 1 + + usage = "CONFIG_FILE" + summary = "Upgrades current db to newer version given configuration file" + group_name = "RhodeCode" + + parser = Command.standard_parser(verbose=True) + + def command(self): + from pylons import config + add_cache(config) + engine = engine_from_config(config, 'sqlalchemy.db1.') + print engine + raise NotImplementedError('Not implemented yet') + + + def update_parser(self): + self.parser.add_option('--sql', + action='store_true', + dest='just_sql', + help="Prints upgrade sql for further investigation", + default=False) diff --git a/rhodecode/lib/dbmigrate/migrate/__init__.py b/rhodecode/lib/dbmigrate/migrate/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/__init__.py @@ -0,0 +1,9 @@ +""" + SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for + database schema version and repository management and + :mod:`migrate.changeset` that allows to define database schema changes + using Python. +""" + +from migrate.versioning import * +from migrate.changeset import * diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/__init__.py b/rhodecode/lib/dbmigrate/migrate/changeset/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/__init__.py @@ -0,0 +1,28 @@ +""" + This module extends SQLAlchemy and provides additional DDL [#]_ + support. + + .. 
[#] SQL Data Definition Language +""" +import re +import warnings + +import sqlalchemy +from sqlalchemy import __version__ as _sa_version + +warnings.simplefilter('always', DeprecationWarning) + +_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split(".")) +SQLA_06 = _sa_version >= (0, 6) + +del re +del _sa_version + +from migrate.changeset.schema import * +from migrate.changeset.constraint import * + +sqlalchemy.schema.Table.__bases__ += (ChangesetTable, ) +sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, ) +sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, ) + +sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, ) diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py b/rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py @@ -0,0 +1,358 @@ +""" + Extensions to SQLAlchemy for altering existing tables. + + At the moment, this isn't so much based off of ANSI as much as + things that just happen to work with multiple databases. 
+""" +import StringIO + +import sqlalchemy as sa +from sqlalchemy.schema import SchemaVisitor +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.sql import ClauseElement +from sqlalchemy.schema import (ForeignKeyConstraint, + PrimaryKeyConstraint, + CheckConstraint, + UniqueConstraint, + Index) + +from migrate import exceptions +from migrate.changeset import constraint, SQLA_06 + +if not SQLA_06: + from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper +else: + from sqlalchemy.schema import AddConstraint, DropConstraint + from sqlalchemy.sql.compiler import DDLCompiler + SchemaGenerator = SchemaDropper = DDLCompiler + + +class AlterTableVisitor(SchemaVisitor): + """Common operations for ``ALTER TABLE`` statements.""" + + if SQLA_06: + # engine.Compiler looks for .statement + # when it spawns off a new compiler + statement = ClauseElement() + + def append(self, s): + """Append content to the SchemaIterator's query buffer.""" + + self.buffer.write(s) + + def execute(self): + """Execute the contents of the SchemaIterator's buffer.""" + try: + return self.connection.execute(self.buffer.getvalue()) + finally: + self.buffer.truncate(0) + + def __init__(self, dialect, connection, **kw): + self.connection = connection + self.buffer = StringIO.StringIO() + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def traverse_single(self, elem): + ret = super(AlterTableVisitor, self).traverse_single(elem) + if ret: + # adapt to 0.6 which uses a string-returning + # object + self.append(" %s" % ret) + + def _to_table(self, param): + """Returns the table object for the given param object.""" + if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)): + ret = param.table + else: + ret = param + return ret + + def start_alter_table(self, param): + """Returns the start of an ``ALTER TABLE`` SQL-Statement. + + Use the param object to determine the table name and use it + for building the SQL statement. 
+ + :param param: object to determine the table from + :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`, + :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`, + or string (table name) + """ + table = self._to_table(param) + self.append('\nALTER TABLE %s ' % self.preparer.format_table(table)) + return table + + +class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator): + """Extends ansisql generator for column creation (alter table add col)""" + + def visit_column(self, column): + """Create a column (table already exists). + + :param column: column object + :type column: :class:`sqlalchemy.Column` instance + """ + if column.default is not None: + self.traverse_single(column.default) + + table = self.start_alter_table(column) + self.append("ADD ") + self.append(self.get_column_specification(column)) + + for cons in column.constraints: + self.traverse_single(cons) + self.execute() + + # ALTER TABLE STATEMENTS + + # add indexes and unique constraints + if column.index_name: + Index(column.index_name,column).create() + elif column.unique_name: + constraint.UniqueConstraint(column, + name=column.unique_name).create() + + # SA bounds FK constraints to table, add manually + for fk in column.foreign_keys: + self.add_foreignkey(fk.constraint) + + # add primary key constraint if needed + if column.primary_key_name: + cons = constraint.PrimaryKeyConstraint(column, + name=column.primary_key_name) + cons.create() + + if SQLA_06: + def add_foreignkey(self, fk): + self.connection.execute(AddConstraint(fk)) + +class ANSIColumnDropper(AlterTableVisitor, SchemaDropper): + """Extends ANSI SQL dropper for column dropping (``ALTER TABLE + DROP COLUMN``). + """ + + def visit_column(self, column): + """Drop a column from its table. 
+ + :param column: the column object + :type column: :class:`sqlalchemy.Column` + """ + table = self.start_alter_table(column) + self.append('DROP COLUMN %s' % self.preparer.format_column(column)) + self.execute() + + +class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator): + """Manages changes to existing schema elements. + + Note that columns are schema elements; ``ALTER TABLE ADD COLUMN`` + is in SchemaGenerator. + + All items may be renamed. Columns can also have many of their properties - + type, for example - changed. + + Each function is passed a tuple, containing (object, name); where + object is a type of object you'd expect for that function + (ie. table for visit_table) and name is the object's new + name. NONE means the name is unchanged. + """ + + def visit_table(self, table): + """Rename a table. Other ops aren't supported.""" + self.start_alter_table(table) + self.append("RENAME TO %s" % self.preparer.quote(table.new_name, + table.quote)) + self.execute() + + def visit_index(self, index): + """Rename an index""" + if hasattr(self, '_validate_identifier'): + # SA <= 0.6.3 + self.append("ALTER INDEX %s RENAME TO %s" % ( + self.preparer.quote( + self._validate_identifier( + index.name, True), index.quote), + self.preparer.quote( + self._validate_identifier( + index.new_name, True), index.quote))) + else: + # SA >= 0.6.5 + self.append("ALTER INDEX %s RENAME TO %s" % ( + self.preparer.quote( + self._index_identifier( + index.name), index.quote), + self.preparer.quote( + self._index_identifier( + index.new_name), index.quote))) + self.execute() + + def visit_column(self, delta): + """Rename/change a column.""" + # ALTER COLUMN is implemented as several ALTER statements + keys = delta.keys() + if 'type' in keys: + self._run_subvisit(delta, self._visit_column_type) + if 'nullable' in keys: + self._run_subvisit(delta, self._visit_column_nullable) + if 'server_default' in keys: + # Skip 'default': only handle server-side defaults, others + # are managed 
by the app, not the db. + self._run_subvisit(delta, self._visit_column_default) + if 'name' in keys: + self._run_subvisit(delta, self._visit_column_name, start_alter=False) + + def _run_subvisit(self, delta, func, start_alter=True): + """Runs visit method based on what needs to be changed on column""" + table = self._to_table(delta.table) + col_name = delta.current_name + if start_alter: + self.start_alter_column(table, col_name) + ret = func(table, delta.result_column, delta) + self.execute() + + def start_alter_column(self, table, col_name): + """Starts ALTER COLUMN""" + self.start_alter_table(table) + self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote)) + + def _visit_column_nullable(self, table, column, delta): + nullable = delta['nullable'] + if nullable: + self.append("DROP NOT NULL") + else: + self.append("SET NOT NULL") + + def _visit_column_default(self, table, column, delta): + default_text = self.get_column_default_string(column) + if default_text is not None: + self.append("SET DEFAULT %s" % default_text) + else: + self.append("DROP DEFAULT") + + def _visit_column_type(self, table, column, delta): + type_ = delta['type'] + if SQLA_06: + type_text = str(type_.compile(dialect=self.dialect)) + else: + type_text = type_.dialect_impl(self.dialect).get_col_spec() + self.append("TYPE %s" % type_text) + + def _visit_column_name(self, table, column, delta): + self.start_alter_table(table) + col_name = self.preparer.quote(delta.current_name, table.quote) + new_name = self.preparer.format_column(delta.result_column) + self.append('RENAME COLUMN %s TO %s' % (col_name, new_name)) + + +class ANSIConstraintCommon(AlterTableVisitor): + """ + Migrate's constraints require a separate creation function from + SA's: Migrate's constraints are created independently of a table; + SA's are created at the same time as the table. + """ + + def get_constraint_name(self, cons): + """Gets a name for the given constraint. 
+ + If the name is already set it will be used otherwise the + constraint's :meth:`autoname` + method is used. + + :param cons: constraint object + """ + if cons.name is not None: + ret = cons.name + else: + ret = cons.name = cons.autoname() + return self.preparer.quote(ret, cons.quote) + + def visit_migrate_primary_key_constraint(self, *p, **k): + self._visit_constraint(*p, **k) + + def visit_migrate_foreign_key_constraint(self, *p, **k): + self._visit_constraint(*p, **k) + + def visit_migrate_check_constraint(self, *p, **k): + self._visit_constraint(*p, **k) + + def visit_migrate_unique_constraint(self, *p, **k): + self._visit_constraint(*p, **k) + +if SQLA_06: + class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): + def _visit_constraint(self, constraint): + constraint.name = self.get_constraint_name(constraint) + self.append(self.process(AddConstraint(constraint))) + self.execute() + + class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): + def _visit_constraint(self, constraint): + constraint.name = self.get_constraint_name(constraint) + self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade))) + self.execute() + +else: + class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): + + def get_constraint_specification(self, cons, **kwargs): + """Constraint SQL generators. + + We cannot use SA visitors because they append comma. 
+ """ + + if isinstance(cons, PrimaryKeyConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons)) + self.append("PRIMARY KEY ") + self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in cons)) + self.define_constraint_deferrability(cons) + elif isinstance(cons, ForeignKeyConstraint): + self.define_foreign_key(cons) + elif isinstance(cons, CheckConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(cons)) + self.append("CHECK (%s)" % cons.sqltext) + self.define_constraint_deferrability(cons) + elif isinstance(cons, UniqueConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(cons)) + self.append("UNIQUE (%s)" % \ + (', '.join(self.preparer.quote(c.name, c.quote) for c in cons))) + self.define_constraint_deferrability(cons) + else: + raise exceptions.InvalidConstraintError(cons) + + def _visit_constraint(self, constraint): + + table = self.start_alter_table(constraint) + constraint.name = self.get_constraint_name(constraint) + self.append("ADD ") + self.get_constraint_specification(constraint) + self.execute() + + + class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): + + def _visit_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP CONSTRAINT ") + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + if constraint.cascade: + self.cascade_constraint(constraint) + self.execute() + + def cascade_constraint(self, constraint): + self.append(" CASCADE") + + +class ANSIDialect(DefaultDialect): + columngenerator = ANSIColumnGenerator + columndropper = ANSIColumnDropper + schemachanger = ANSISchemaChanger + constraintgenerator = ANSIConstraintGenerator + constraintdropper = ANSIConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/constraint.py 
b/rhodecode/lib/dbmigrate/migrate/changeset/constraint.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/constraint.py @@ -0,0 +1,202 @@ +""" + This module defines standalone schema constraint classes. +""" +from sqlalchemy import schema + +from migrate.exceptions import * +from migrate.changeset import SQLA_06 + +class ConstraintChangeset(object): + """Base class for Constraint classes.""" + + def _normalize_columns(self, cols, table_name=False): + """Given: column objects or names; return col names and + (maybe) a table""" + colnames = [] + table = None + for col in cols: + if isinstance(col, schema.Column): + if col.table is not None and table is None: + table = col.table + if table_name: + col = '.'.join((col.table.name, col.name)) + else: + col = col.name + colnames.append(col) + return colnames, table + + def __do_imports(self, visitor_name, *a, **kw): + engine = kw.pop('engine', self.table.bind) + from migrate.changeset.databases.visitor import (get_engine_visitor, + run_single_visitor) + visitorcallable = get_engine_visitor(engine, visitor_name) + run_single_visitor(engine, visitorcallable, self, *a, **kw) + + def create(self, *a, **kw): + """Create the constraint in the database. + + :param engine: the database engine to use. If this is \ + :keyword:`None` the instance's engine will be used + :type engine: :class:`sqlalchemy.engine.base.Engine` + :param connection: reuse connection instead of creating a new one. + :type connection: :class:`sqlalchemy.engine.base.Connection` instance + """ + # TODO: set the parent here instead of in __init__ + self.__do_imports('constraintgenerator', *a, **kw) + + def drop(self, *a, **kw): + """Drop the constraint from the database. + + :param engine: the database engine to use. 
If this is + :keyword:`None` the instance's engine will be used + :param cascade: Issue CASCADE drop if database supports it + :type engine: :class:`sqlalchemy.engine.base.Engine` + :type cascade: bool + :param connection: reuse connection instead of creating a new one. + :type connection: :class:`sqlalchemy.engine.base.Connection` instance + :returns: The constraint instance itself (columns are not cleared) + """ + self.cascade = kw.pop('cascade', False) + self.__do_imports('constraintdropper', *a, **kw) + # the spirit of Constraint objects is that they + # are immutable (just like in a DB. they're only ADDed + # or DROPped). + #self.columns.clear() + return self + + +class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint): + """Construct PrimaryKeyConstraint + + Migrate's additional parameters: + + :param cols: Columns in constraint. + :param table: If columns are passed as strings, this kw is required + :type table: Table instance + :type cols: strings or Column instances + """ + + __migrate_visit_name__ = 'migrate_primary_key_constraint' + + def __init__(self, *cols, **kwargs): + colnames, table = self._normalize_columns(cols) + table = kwargs.pop('table', table) + super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs) + if table is not None: + self._set_parent(table) + + + def autoname(self): + """Mimic the database's automatic constraint names""" + return "%s_pkey" % self.table.name + + +class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint): + """Construct ForeignKeyConstraint + + Migrate's additional parameters: + + :param columns: Columns in constraint + :param refcolumns: Columns that this FK refers to in another table. 
+ :param table: If columns are passed as strings, this kw is required + :type table: Table instance + :type columns: list of strings or Column instances + :type refcolumns: list of strings or Column instances + """ + + __migrate_visit_name__ = 'migrate_foreign_key_constraint' + + def __init__(self, columns, refcolumns, *args, **kwargs): + colnames, table = self._normalize_columns(columns) + table = kwargs.pop('table', table) + refcolnames, reftable = self._normalize_columns(refcolumns, + table_name=True) + super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args, + **kwargs) + if table is not None: + self._set_parent(table) + + @property + def referenced(self): + return [e.column for e in self.elements] + + @property + def reftable(self): + return self.referenced[0].table + + def autoname(self): + """Mimic the database's automatic constraint names""" + if hasattr(self.columns, 'keys'): + # SA <= 0.5 + firstcol = self.columns[self.columns.keys()[0]] + ret = "%(table)s_%(firstcolumn)s_fkey" % dict( + table=firstcol.table.name, + firstcolumn=firstcol.name,) + else: + # SA >= 0.6 + ret = "%(table)s_%(firstcolumn)s_fkey" % dict( + table=self.table.name, + firstcolumn=self.columns[0],) + return ret + + +class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): + """Construct CheckConstraint + + Migrate's additional parameters: + + :param sqltext: Plain SQL text to check condition + :param columns: If no name is supplied, you must supply this kw\ + to autoname constraint + :param table: If columns are passed as strings, this kw is required + :type table: Table instance + :type columns: list of Column instances + :type sqltext: string + """ + + __migrate_visit_name__ = 'migrate_check_constraint' + + def __init__(self, sqltext, *args, **kwargs): + cols = kwargs.pop('columns', []) + if not cols and not kwargs.get('name', False): + raise InvalidConstraintError('You must either set "name"' + 'parameter or "columns" to autogenarate it.') + colnames, 
table = self._normalize_columns(cols) + table = kwargs.pop('table', table) + schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs) + if table is not None: + if not SQLA_06: + self.table = table + self._set_parent(table) + self.colnames = colnames + + def autoname(self): + return "%(table)s_%(cols)s_check" % \ + dict(table=self.table.name, cols="_".join(self.colnames)) + + +class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint): + """Construct UniqueConstraint + + Migrate's additional parameters: + + :param cols: Columns in constraint. + :param table: If columns are passed as strings, this kw is required + :type table: Table instance + :type cols: strings or Column instances + + .. versionadded:: 0.6.0 + """ + + __migrate_visit_name__ = 'migrate_unique_constraint' + + def __init__(self, *cols, **kwargs): + self.colnames, table = self._normalize_columns(cols) + table = kwargs.pop('table', table) + super(UniqueConstraint, self).__init__(*self.colnames, **kwargs) + if table is not None: + self._set_parent(table) + + def autoname(self): + """Mimic the database's automatic constraint names""" + return "%s_%s_key" % (self.table.name, self.colnames[0]) diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/__init__.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/__init__.py @@ -0,0 +1,10 @@ +""" + This module contains database dialect specific changeset + implementations. +""" +__all__ = [ + 'postgres', + 'sqlite', + 'mysql', + 'oracle', +] diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/firebird.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/firebird.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/firebird.py @@ -0,0 +1,80 @@ +""" + Firebird database specific implementations of changeset classes. 
+""" +from sqlalchemy.databases import firebird as sa_base + +from migrate import exceptions +from migrate.changeset import ansisql, SQLA_06 + + +if SQLA_06: + FBSchemaGenerator = sa_base.FBDDLCompiler +else: + FBSchemaGenerator = sa_base.FBSchemaGenerator + +class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator): + """Firebird column generator implementation.""" + + +class FBColumnDropper(ansisql.ANSIColumnDropper): + """Firebird column dropper implementation.""" + + def visit_column(self, column): + """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax + + Drop primary key and unique constraints that contain the dropped column.""" + if column.primary_key: + if column.table.primary_key.columns.contains_column(column): + column.table.primary_key.drop() + # TODO: recreate primary key if it references more than this column + if column.unique or getattr(column, 'unique_name', None): + for cons in column.table.constraints: + if cons.contains_column(column): + cons.drop() + # TODO: recreate unique constraint if it references more than this column + + table = self.start_alter_table(column) + self.append('DROP %s' % self.preparer.format_column(column)) + self.execute() + + +class FBSchemaChanger(ansisql.ANSISchemaChanger): + """Firebird schema changer implementation.""" + + def visit_table(self, table): + """Rename table not supported""" + raise exceptions.NotSupportedError( + "Firebird does not support renaming tables.") + + def _visit_column_name(self, table, column, delta): + self.start_alter_table(table) + col_name = self.preparer.quote(delta.current_name, table.quote) + new_name = self.preparer.format_column(delta.result_column) + self.append('ALTER COLUMN %s TO %s' % (col_name, new_name)) + + def _visit_column_nullable(self, table, column, delta): + """Changing NULL is not supported""" + # TODO: http://www.firebirdfaq.org/faq103/ + raise exceptions.NotSupportedError( + "Firebird does not support altering NULL bevahior.") + + 
+class FBConstraintGenerator(ansisql.ANSIConstraintGenerator): + """Firebird constraint generator implementation.""" + + +class FBConstraintDropper(ansisql.ANSIConstraintDropper): + """Firebird constraint dropper implementation.""" + + def cascade_constraint(self, constraint): + """Cascading constraints is not supported""" + raise exceptions.NotSupportedError( + "Firebird does not support cascading constraints") + + +class FBDialect(ansisql.ANSIDialect): + columngenerator = FBColumnGenerator + columndropper = FBColumnDropper + schemachanger = FBSchemaChanger + constraintgenerator = FBConstraintGenerator + constraintdropper = FBConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/mysql.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/mysql.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/mysql.py @@ -0,0 +1,94 @@ +""" + MySQL database specific implementations of changeset classes. +""" + +from sqlalchemy.databases import mysql as sa_base +from sqlalchemy import types as sqltypes + +from migrate import exceptions +from migrate.changeset import ansisql, SQLA_06 + + +if not SQLA_06: + MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator +else: + MySQLSchemaGenerator = sa_base.MySQLDDLCompiler + +class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator): + pass + + +class MySQLColumnDropper(ansisql.ANSIColumnDropper): + pass + + +class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger): + + def visit_column(self, delta): + table = delta.table + colspec = self.get_column_specification(delta.result_column) + if delta.result_column.autoincrement: + primary_keys = [c for c in table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)] + + if primary_keys: + first = primary_keys.pop(0) + if first.name == delta.current_name: + colspec += " AUTO_INCREMENT" + old_col_name = 
self.preparer.quote(delta.current_name, table.quote) + + self.start_alter_table(table) + + self.append("CHANGE COLUMN %s " % old_col_name) + self.append(colspec) + self.execute() + + def visit_index(self, param): + # If MySQL can do this, I can't find how + raise exceptions.NotSupportedError("MySQL cannot rename indexes") + + +class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator): + pass + +if SQLA_06: + class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper): + def visit_migrate_check_constraint(self, *p, **k): + raise exceptions.NotSupportedError("MySQL does not support CHECK" + " constraints, use triggers instead.") + +else: + class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): + + def visit_migrate_primary_key_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP PRIMARY KEY") + self.execute() + + def visit_migrate_foreign_key_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP FOREIGN KEY ") + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + self.execute() + + def visit_migrate_check_constraint(self, *p, **k): + raise exceptions.NotSupportedError("MySQL does not support CHECK" + " constraints, use triggers instead.") + + def visit_migrate_unique_constraint(self, constraint, *p, **k): + self.start_alter_table(constraint) + self.append('DROP INDEX ') + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + self.execute() + + +class MySQLDialect(ansisql.ANSIDialect): + columngenerator = MySQLColumnGenerator + columndropper = MySQLColumnDropper + schemachanger = MySQLSchemaChanger + constraintgenerator = MySQLConstraintGenerator + constraintdropper = MySQLConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/oracle.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/oracle.py new file mode 
100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/oracle.py @@ -0,0 +1,111 @@ +""" + Oracle database specific implementations of changeset classes. +""" +import sqlalchemy as sa +from sqlalchemy.databases import oracle as sa_base + +from migrate import exceptions +from migrate.changeset import ansisql, SQLA_06 + + +if not SQLA_06: + OracleSchemaGenerator = sa_base.OracleSchemaGenerator +else: + OracleSchemaGenerator = sa_base.OracleDDLCompiler + + +class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator): + pass + + +class OracleColumnDropper(ansisql.ANSIColumnDropper): + pass + + +class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger): + + def get_column_specification(self, column, **kwargs): + # Ignore the NOT NULL generated + override_nullable = kwargs.pop('override_nullable', None) + if override_nullable: + orig = column.nullable + column.nullable = True + ret = super(OracleSchemaChanger, self).get_column_specification( + column, **kwargs) + if override_nullable: + column.nullable = orig + return ret + + def visit_column(self, delta): + keys = delta.keys() + + if 'name' in keys: + self._run_subvisit(delta, + self._visit_column_name, + start_alter=False) + + if len(set(('type', 'nullable', 'server_default')).intersection(keys)): + self._run_subvisit(delta, + self._visit_column_change, + start_alter=False) + + def _visit_column_change(self, table, column, delta): + # Oracle cannot drop a default once created, but it can set it + # to null. We'll do that if default=None + # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234 + dropdefault_hack = (column.server_default is None \ + and 'server_default' in delta.keys()) + # Oracle apparently doesn't like it when we say "not null" if + # the column's already not null. 
Fudge it, so we don't need a + # new function + notnull_hack = ((not column.nullable) \ + and ('nullable' not in delta.keys())) + # We need to specify NULL if we're removing a NOT NULL + # constraint + null_hack = (column.nullable and ('nullable' in delta.keys())) + + if dropdefault_hack: + column.server_default = sa.PassiveDefault(sa.sql.null()) + if notnull_hack: + column.nullable = True + colspec = self.get_column_specification(column, + override_nullable=null_hack) + if null_hack: + colspec += ' NULL' + if notnull_hack: + column.nullable = False + if dropdefault_hack: + column.server_default = None + + self.start_alter_table(table) + self.append("MODIFY (") + self.append(colspec) + self.append(")") + + +class OracleConstraintCommon(object): + + def get_constraint_name(self, cons): + # Oracle constraints can't guess their name like other DBs + if not cons.name: + raise exceptions.NotSupportedError( + "Oracle constraint names must be explicitly stated") + return cons.name + + +class OracleConstraintGenerator(OracleConstraintCommon, + ansisql.ANSIConstraintGenerator): + pass + + +class OracleConstraintDropper(OracleConstraintCommon, + ansisql.ANSIConstraintDropper): + pass + + +class OracleDialect(ansisql.ANSIDialect): + columngenerator = OracleColumnGenerator + columndropper = OracleColumnDropper + schemachanger = OracleSchemaChanger + constraintgenerator = OracleConstraintGenerator + constraintdropper = OracleConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/postgres.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/postgres.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/postgres.py @@ -0,0 +1,46 @@ +""" + `PostgreSQL`_ database specific implementations of changeset classes. + + .. 
_`PostgreSQL`: http://www.postgresql.org/ +""" +from migrate.changeset import ansisql, SQLA_06 + +if not SQLA_06: + from sqlalchemy.databases import postgres as sa_base + PGSchemaGenerator = sa_base.PGSchemaGenerator +else: + from sqlalchemy.databases import postgresql as sa_base + PGSchemaGenerator = sa_base.PGDDLCompiler + + +class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator): + """PostgreSQL column generator implementation.""" + pass + + +class PGColumnDropper(ansisql.ANSIColumnDropper): + """PostgreSQL column dropper implementation.""" + pass + + +class PGSchemaChanger(ansisql.ANSISchemaChanger): + """PostgreSQL schema changer implementation.""" + pass + + +class PGConstraintGenerator(ansisql.ANSIConstraintGenerator): + """PostgreSQL constraint generator implementation.""" + pass + + +class PGConstraintDropper(ansisql.ANSIConstraintDropper): + """PostgreSQL constraint dropper implementation.""" + pass + + +class PGDialect(ansisql.ANSIDialect): + columngenerator = PGColumnGenerator + columndropper = PGColumnDropper + schemachanger = PGSchemaChanger + constraintgenerator = PGConstraintGenerator + constraintdropper = PGConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/sqlite.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/sqlite.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/sqlite.py @@ -0,0 +1,148 @@ +""" + `SQLite`_ database specific implementations of changeset classes. + + ..
_`SQLite`: http://www.sqlite.org/ +""" +from UserDict import DictMixin +from copy import copy + +from sqlalchemy.databases import sqlite as sa_base + +from migrate import exceptions +from migrate.changeset import ansisql, SQLA_06 + + +if not SQLA_06: + SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator +else: + SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler + +class SQLiteCommon(object): + + def _not_supported(self, op): + raise exceptions.NotSupportedError("SQLite does not support " + "%s; see http://www.sqlite.org/lang_altertable.html" % op) + + +class SQLiteHelper(SQLiteCommon): + + def recreate_table(self,table,column=None,delta=None): + table_name = self.preparer.format_table(table) + + # we remove all indexes so as not to have + # problems during copy and re-create + for index in table.indexes: + index.drop() + + self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name) + self.execute() + + insertion_string = self._modify_table(table, column, delta) + + table.create() + self.append(insertion_string % {'table_name': table_name}) + self.execute() + self.append('DROP TABLE migration_tmp') + self.execute() + + def visit_column(self, delta): + if isinstance(delta, DictMixin): + column = delta.result_column + table = self._to_table(delta.table) + else: + column = delta + table = self._to_table(column.table) + self.recreate_table(table,column,delta) + +class SQLiteColumnGenerator(SQLiteSchemaGenerator, + ansisql.ANSIColumnGenerator, + # at the end so we get the normal + # visit_column by default + SQLiteHelper, + SQLiteCommon + ): + """SQLite ColumnGenerator""" + + def _modify_table(self, table, column, delta): + columns = ' ,'.join(map( + self.preparer.format_column, + [c for c in table.columns if c.name!=column.name])) + return ('INSERT INTO %%(table_name)s (%(cols)s) ' + 'SELECT %(cols)s from migration_tmp')%{'cols':columns} + + def visit_column(self,column): + if column.foreign_keys: + SQLiteHelper.visit_column(self,column) + else: + 
super(SQLiteColumnGenerator,self).visit_column(column) + +class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper): + """SQLite ColumnDropper""" + + def _modify_table(self, table, column, delta): + columns = ' ,'.join(map(self.preparer.format_column, table.columns)) + return 'INSERT INTO %(table_name)s SELECT ' + columns + \ + ' from migration_tmp' + + +class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger): + """SQLite SchemaChanger""" + + def _modify_table(self, table, column, delta): + return 'INSERT INTO %(table_name)s SELECT * from migration_tmp' + + def visit_index(self, index): + """Does not support ALTER INDEX""" + self._not_supported('ALTER INDEX') + + +class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon): + + def visit_migrate_primary_key_constraint(self, constraint): + tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )" + cols = ', '.join(map(self.preparer.format_column, constraint.columns)) + tname = self.preparer.format_table(constraint.table) + name = self.get_constraint_name(constraint) + msg = tmpl % (name, tname, cols) + self.append(msg) + self.execute() + + def _modify_table(self, table, column, delta): + return 'INSERT INTO %(table_name)s SELECT * from migration_tmp' + + def visit_migrate_foreign_key_constraint(self, *p, **k): + self.recreate_table(p[0].table) + + def visit_migrate_unique_constraint(self, *p, **k): + self.recreate_table(p[0].table) + + +class SQLiteConstraintDropper(ansisql.ANSIColumnDropper, + SQLiteCommon, + ansisql.ANSIConstraintCommon): + + def visit_migrate_primary_key_constraint(self, constraint): + tmpl = "DROP INDEX %s " + name = self.get_constraint_name(constraint) + msg = tmpl % (name) + self.append(msg) + self.execute() + + def visit_migrate_foreign_key_constraint(self, *p, **k): + self._not_supported('ALTER TABLE DROP CONSTRAINT') + + def visit_migrate_check_constraint(self, *p, **k): + self._not_supported('ALTER TABLE DROP CONSTRAINT') + + def 
visit_migrate_unique_constraint(self, *p, **k): + self._not_supported('ALTER TABLE DROP CONSTRAINT') + + +# TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index + +class SQLiteDialect(ansisql.ANSIDialect): + columngenerator = SQLiteColumnGenerator + columndropper = SQLiteColumnDropper + schemachanger = SQLiteSchemaChanger + constraintgenerator = SQLiteConstraintGenerator + constraintdropper = SQLiteConstraintDropper diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/databases/visitor.py b/rhodecode/lib/dbmigrate/migrate/changeset/databases/visitor.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/databases/visitor.py @@ -0,0 +1,78 @@ +""" + Module for visitor class mapping. +""" +import sqlalchemy as sa + +from migrate.changeset import ansisql +from migrate.changeset.databases import (sqlite, + postgres, + mysql, + oracle, + firebird) + + +# Map SA dialects to the corresponding Migrate extensions +DIALECTS = { + "default": ansisql.ANSIDialect, + "sqlite": sqlite.SQLiteDialect, + "postgres": postgres.PGDialect, + "postgresql": postgres.PGDialect, + "mysql": mysql.MySQLDialect, + "oracle": oracle.OracleDialect, + "firebird": firebird.FBDialect, +} + + +def get_engine_visitor(engine, name): + """ + Get the visitor implementation for the given database engine. + + :param engine: SQLAlchemy Engine + :param name: Name of the visitor + :type name: string + :type engine: Engine + :returns: visitor + """ + # TODO: link to supported visitors + return get_dialect_visitor(engine.dialect, name) + + +def get_dialect_visitor(sa_dialect, name): + """ + Get the visitor implementation for the given dialect. + + Finds the visitor implementation based on the dialect class and + returns and instance initialized with the given name. + + Binds dialect specific preparer to visitor. 
+ """ + + # map sa dialect to migrate dialect and return visitor + sa_dialect_name = getattr(sa_dialect, 'name', 'default') + migrate_dialect_cls = DIALECTS[sa_dialect_name] + visitor = getattr(migrate_dialect_cls, name) + + # bind preparer + visitor.preparer = sa_dialect.preparer(sa_dialect) + + return visitor + +def run_single_visitor(engine, visitorcallable, element, + connection=None, **kwargs): + """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor` + with support for migrate visitors. + """ + if connection is None: + conn = engine.contextual_connect(close_with_result=False) + else: + conn = connection + visitor = visitorcallable(engine.dialect, conn) + try: + if hasattr(element, '__migrate_visit_name__'): + fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__) + else: + fn = getattr(visitor, 'visit_' + element.__visit_name__) + fn(element, **kwargs) + finally: + if connection is None: + conn.close() diff --git a/rhodecode/lib/dbmigrate/migrate/changeset/schema.py b/rhodecode/lib/dbmigrate/migrate/changeset/schema.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/changeset/schema.py @@ -0,0 +1,669 @@ +""" + Schema module providing common schema operations. +""" +import warnings + +from UserDict import DictMixin + +import sqlalchemy + +from sqlalchemy.schema import ForeignKeyConstraint +from sqlalchemy.schema import UniqueConstraint + +from migrate.exceptions import * +from migrate.changeset import SQLA_06 +from migrate.changeset.databases.visitor import (get_engine_visitor, + run_single_visitor) + + +__all__ = [ + 'create_column', + 'drop_column', + 'alter_column', + 'rename_table', + 'rename_index', + 'ChangesetTable', + 'ChangesetColumn', + 'ChangesetIndex', + 'ChangesetDefaultClause', + 'ColumnDelta', +] + +DEFAULT_ALTER_METADATA = True + + +def create_column(column, table=None, *p, **kw): + """Create a column, given the table. + + API to :meth:`ChangesetColumn.create`. 
+ """ + if table is not None: + return table.create_column(column, *p, **kw) + return column.create(*p, **kw) + + +def drop_column(column, table=None, *p, **kw): + """Drop a column, given the table. + + API to :meth:`ChangesetColumn.drop`. + """ + if table is not None: + return table.drop_column(column, *p, **kw) + return column.drop(*p, **kw) + + +def rename_table(table, name, engine=None, **kw): + """Rename a table. + + If Table instance is given, engine is not used. + + API to :meth:`ChangesetTable.rename`. + + :param table: Table to be renamed. + :param name: New name for Table. + :param engine: Engine instance. + :type table: string or Table instance + :type name: string + :type engine: obj + """ + table = _to_table(table, engine) + table.rename(name, **kw) + + +def rename_index(index, name, table=None, engine=None, **kw): + """Rename an index. + + If Index instance is given, + table and engine are not used. + + API to :meth:`ChangesetIndex.rename`. + + :param index: Index to be renamed. + :param name: New name for index. + :param table: Table to which the Index refers. + :param engine: Engine instance. + :type index: string or Index instance + :type name: string + :type table: string or Table instance + :type engine: obj + """ + index = _to_index(index, table, engine) + index.rename(name, **kw) + + +def alter_column(*p, **k): + """Alter a column. + + This is a helper function that creates a :class:`ColumnDelta` and + runs it. + + :argument column: + The name of the column to be altered or a + :class:`ChangesetColumn` column representing it. + + :param table: + A :class:`~sqlalchemy.schema.Table` or table name + for the table where the column will be changed. + + :param engine: + The :class:`~sqlalchemy.engine.base.Engine` to use for table + reflection and schema alterations. + + :param alter_metadata: + If `True`, which is the default, the + :class:`~sqlalchemy.schema.Column` will also be modified.
+ If `False`, the :class:`~sqlalchemy.schema.Column` will be left + as it was. + + :returns: A :class:`ColumnDelta` instance representing the change. + + + """ + + k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA) + + if 'table' not in k and isinstance(p[0], sqlalchemy.Column): + k['table'] = p[0].table + if 'engine' not in k: + k['engine'] = k['table'].bind + + # deprecation + if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column): + warnings.warn( + "Passing a Column object to alter_column is deprecated." + " Just pass in keyword parameters instead.", + MigrateDeprecationWarning + ) + engine = k['engine'] + delta = ColumnDelta(*p, **k) + + visitorcallable = get_engine_visitor(engine, 'schemachanger') + engine._run_visitor(visitorcallable, delta) + + return delta + + +def _to_table(table, engine=None): + """Return if instance of Table, else construct new with metadata""" + if isinstance(table, sqlalchemy.Table): + return table + + # Given: table name, maybe an engine + meta = sqlalchemy.MetaData() + if engine is not None: + meta.bind = engine + return sqlalchemy.Table(table, meta) + + +def _to_index(index, table=None, engine=None): + """Return if instance of Index, else construct new with metadata""" + if isinstance(index, sqlalchemy.Index): + return index + + # Given: index name; table name required + table = _to_table(table, engine) + ret = sqlalchemy.Index(index) + ret.table = table + return ret + + +class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem): + """Extracts the differences between two columns/column-parameters + + May receive parameters arranged in several different ways: + + * **current_column, new_column, \*p, \*\*kw** + Additional parameters can be specified to override column + differences. + + * **current_column, \*p, \*\*kw** + Additional parameters alter current_column. Table name is extracted + from current_column object. + Name is changed to current_column.name from current_name, + if current_name is specified. 
+ + * **current_col_name, \*p, \*\*kw** + Table kw must be specified. + + :param table: Table to which current Column should be bound.\ + If table name is given, reflection will be used. + :type table: string or Table instance + :param alter_metadata: If True, it will apply changes to metadata. + :type alter_metadata: bool + :param metadata: If `alter_metadata` is true, \ + metadata is used to reflect table names into + :type metadata: :class:`MetaData` instance + :param engine: When reflecting tables, either engine or metadata must \ + be specified to acquire engine object. + :type engine: :class:`Engine` instance + :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \ + `result_column` through :func:`dict` alike object. + + * :class:`ColumnDelta`.result_column is altered column with new attributes + + * :class:`ColumnDelta`.current_name is current name of column in db + + + """ + + # Column attributes that can be altered + diff_keys = ('name', 'type', 'primary_key', 'nullable', + 'server_onupdate', 'server_default', 'autoincrement') + diffs = dict() + __visit_name__ = 'column' + + def __init__(self, *p, **kw): + self.alter_metadata = kw.pop("alter_metadata", False) + self.meta = kw.pop("metadata", None) + self.engine = kw.pop("engine", None) + + # Things are initialized differently depending on how many column + # parameters are given. Figure out how many and call the appropriate + # method.
+ if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column): + # At least one column specified + if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column): + # Two columns specified + diffs = self.compare_2_columns(*p, **kw) + else: + # Exactly one column specified + diffs = self.compare_1_column(*p, **kw) + else: + # Zero columns specified + if not len(p) or not isinstance(p[0], basestring): + raise ValueError("First argument must be column name") + diffs = self.compare_parameters(*p, **kw) + + self.apply_diffs(diffs) + + def __repr__(self): + return '' % (self.alter_metadata, + super(ColumnDelta, self).__repr__()) + + def __getitem__(self, key): + if key not in self.keys(): + raise KeyError("No such diff key, available: %s" % self.diffs) + return getattr(self.result_column, key) + + def __setitem__(self, key, value): + if key not in self.keys(): + raise KeyError("No such diff key, available: %s" % self.diffs) + setattr(self.result_column, key, value) + + def __delitem__(self, key): + raise NotImplementedError + + def keys(self): + return self.diffs.keys() + + def compare_parameters(self, current_name, *p, **k): + """Compares Column objects with reflection""" + self.table = k.pop('table') + self.result_column = self._table.c.get(current_name) + if len(p): + k = self._extract_parameters(p, k, self.result_column) + return k + + def compare_1_column(self, col, *p, **k): + """Compares one Column object""" + self.table = k.pop('table', None) + if self.table is None: + self.table = col.table + self.result_column = col + if len(p): + k = self._extract_parameters(p, k, self.result_column) + return k + + def compare_2_columns(self, old_col, new_col, *p, **k): + """Compares two Column objects""" + self.process_column(new_col) + self.table = k.pop('table', None) + # we cannot use bool() on table in SA06 + if self.table is None: + self.table = old_col.table + if self.table is None: + new_col.table + self.result_column = old_col + + # set differences + # leave out some stuff for 
later comp + for key in (set(self.diff_keys) - set(('type',))): + val = getattr(new_col, key, None) + if getattr(self.result_column, key, None) != val: + k.setdefault(key, val) + + # inspect types + if not self.are_column_types_eq(self.result_column.type, new_col.type): + k.setdefault('type', new_col.type) + + if len(p): + k = self._extract_parameters(p, k, self.result_column) + return k + + def apply_diffs(self, diffs): + """Populate dict and column object with new values""" + self.diffs = diffs + for key in self.diff_keys: + if key in diffs: + setattr(self.result_column, key, diffs[key]) + + self.process_column(self.result_column) + + # create an instance of class type if not yet + if 'type' in diffs and callable(self.result_column.type): + self.result_column.type = self.result_column.type() + + # add column to the table + if self.table is not None and self.alter_metadata: + self.result_column.add_to_table(self.table) + + def are_column_types_eq(self, old_type, new_type): + """Compares two types to be equal""" + ret = old_type.__class__ == new_type.__class__ + + # String length is a special case + if ret and isinstance(new_type, sqlalchemy.types.String): + ret = (getattr(old_type, 'length', None) == \ + getattr(new_type, 'length', None)) + return ret + + def _extract_parameters(self, p, k, column): + """Extracts data from p and modifies diffs""" + p = list(p) + while len(p): + if isinstance(p[0], basestring): + k.setdefault('name', p.pop(0)) + elif isinstance(p[0], sqlalchemy.types.AbstractType): + k.setdefault('type', p.pop(0)) + elif callable(p[0]): + p[0] = p[0]() + else: + break + + if len(p): + new_col = column.copy_fixed() + new_col._init_items(*p) + k = self.compare_2_columns(column, new_col, **k) + return k + + def process_column(self, column): + """Processes default values for column""" + # XXX: this is a snippet from SA processing of positional parameters + if not SQLA_06 and column.args: + toinit = list(column.args) + else: + toinit = list() + + if 
column.server_default is not None: + if isinstance(column.server_default, sqlalchemy.FetchedValue): + toinit.append(column.server_default) + else: + toinit.append(sqlalchemy.DefaultClause(column.server_default)) + if column.server_onupdate is not None: + if isinstance(column.server_onupdate, FetchedValue): + toinit.append(column.server_default) + else: + toinit.append(sqlalchemy.DefaultClause(column.server_onupdate, + for_update=True)) + if toinit: + column._init_items(*toinit) + + if not SQLA_06: + column.args = [] + + def _get_table(self): + return getattr(self, '_table', None) + + def _set_table(self, table): + if isinstance(table, basestring): + if self.alter_metadata: + if not self.meta: + raise ValueError("metadata must be specified for table" + " reflection when using alter_metadata") + meta = self.meta + if self.engine: + meta.bind = self.engine + else: + if not self.engine and not self.meta: + raise ValueError("engine or metadata must be specified" + " to reflect tables") + if not self.engine: + self.engine = self.meta.bind + meta = sqlalchemy.MetaData(bind=self.engine) + self._table = sqlalchemy.Table(table, meta, autoload=True) + elif isinstance(table, sqlalchemy.Table): + self._table = table + if not self.alter_metadata: + self._table.meta = sqlalchemy.MetaData(bind=self._table.bind) + + def _get_result_column(self): + return getattr(self, '_result_column', None) + + def _set_result_column(self, column): + """Set Column to Table based on alter_metadata evaluation.""" + self.process_column(column) + if not hasattr(self, 'current_name'): + self.current_name = column.name + if self.alter_metadata: + self._result_column = column + else: + self._result_column = column.copy_fixed() + + table = property(_get_table, _set_table) + result_column = property(_get_result_column, _set_result_column) + + +class ChangesetTable(object): + """Changeset extensions to SQLAlchemy tables.""" + + def create_column(self, column, *p, **kw): + """Creates a column. 
+ + The column parameter may be a column definition or the name of + a column in this table. + + API to :meth:`ChangesetColumn.create` + + :param column: Column to be created + :type column: Column instance or string + """ + if not isinstance(column, sqlalchemy.Column): + # It's a column name + column = getattr(self.c, str(column)) + column.create(table=self, *p, **kw) + + def drop_column(self, column, *p, **kw): + """Drop a column, given its name or definition. + + API to :meth:`ChangesetColumn.drop` + + :param column: Column to be dropped + :type column: Column instance or string + """ + if not isinstance(column, sqlalchemy.Column): + # It's a column name + try: + column = getattr(self.c, str(column)) + except AttributeError: + # That column isn't part of the table. We don't need + # its entire definition to drop the column, just its + # name, so create a dummy column with the same name. + column = sqlalchemy.Column(str(column), sqlalchemy.Integer()) + column.drop(table=self, *p, **kw) + + def rename(self, name, connection=None, **kwargs): + """Rename this table. + + :param name: New name of the table. + :type name: string + :param alter_metadata: If True, table will be removed from metadata + :type alter_metadata: bool + :param connection: reuse connection instead of creating a new one.
+ :type connection: :class:`sqlalchemy.engine.base.Connection` instance + """ + self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) + engine = self.bind + self.new_name = name + visitorcallable = get_engine_visitor(engine, 'schemachanger') + run_single_visitor(engine, visitorcallable, self, connection, **kwargs) + + # Fix metadata registration + if self.alter_metadata: + self.name = name + self.deregister() + self._set_parent(self.metadata) + + def _meta_key(self): + return sqlalchemy.schema._get_table_key(self.name, self.schema) + + def deregister(self): + """Remove this table from its metadata""" + key = self._meta_key() + meta = self.metadata + if key in meta.tables: + del meta.tables[key] + + +class ChangesetColumn(object): + """Changeset extensions to SQLAlchemy columns.""" + + def alter(self, *p, **k): + """Makes a call to :func:`alter_column` for the column this + method is called on. + """ + if 'table' not in k: + k['table'] = self.table + if 'engine' not in k: + k['engine'] = k['table'].bind + return alter_column(self, *p, **k) + + def create(self, table=None, index_name=None, unique_name=None, + primary_key_name=None, populate_default=True, connection=None, **kwargs): + """Create this column in the database. + + Assumes the given table exists. ``ALTER TABLE ADD COLUMN``, + for most databases. + + :param table: Table instance to create on. + :param index_name: Creates :class:`ChangesetIndex` on this column. + :param unique_name: Creates :class:\ +`~migrate.changeset.constraint.UniqueConstraint` on this column. + :param primary_key_name: Creates :class:\ +`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column. + :param alter_metadata: If True, column will be added to table object. + :param populate_default: If True, created column will be \ +populated with defaults + :param connection: reuse connection instead of creating a new one.
+ :type table: Table instance + :type index_name: string + :type unique_name: string + :type primary_key_name: string + :type alter_metadata: bool + :type populate_default: bool + :type connection: :class:`sqlalchemy.engine.base.Connection` instance + + :returns: self + """ + self.populate_default = populate_default + self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) + self.index_name = index_name + self.unique_name = unique_name + self.primary_key_name = primary_key_name + for cons in ('index_name', 'unique_name', 'primary_key_name'): + self._check_sanity_constraints(cons) + + if self.alter_metadata: + self.add_to_table(table) + engine = self.table.bind + visitorcallable = get_engine_visitor(engine, 'columngenerator') + engine._run_visitor(visitorcallable, self, connection, **kwargs) + + # TODO: reuse existing connection + if self.populate_default and self.default is not None: + stmt = table.update().values({self: engine._execute_default(self.default)}) + engine.execute(stmt) + + return self + + def drop(self, table=None, connection=None, **kwargs): + """Drop this column from the database, leaving its table intact. + + ``ALTER TABLE DROP COLUMN``, for most databases. + + :param alter_metadata: If True, column will be removed from table object. + :type alter_metadata: bool + :param connection: reuse connection instead of creating a new one.
+ :type connection: :class:`sqlalchemy.engine.base.Connection` instance + """ + self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) + if table is not None: + self.table = table + engine = self.table.bind + if self.alter_metadata: + self.remove_from_table(self.table, unset_table=False) + visitorcallable = get_engine_visitor(engine, 'columndropper') + engine._run_visitor(visitorcallable, self, connection, **kwargs) + if self.alter_metadata: + self.table = None + return self + + def add_to_table(self, table): + if table is not None and self.table is None: + self._set_parent(table) + + def _col_name_in_constraint(self, cons, name): + return False + + def remove_from_table(self, table, unset_table=True): + # TODO: remove primary keys, constraints, etc + if unset_table: + self.table = None + + to_drop = set() + for index in table.indexes: + columns = [] + for col in index.columns: + if col.name != self.name: + columns.append(col) + if columns: + index.columns = columns + else: + to_drop.add(index) + table.indexes = table.indexes - to_drop + + to_drop = set() + for cons in table.constraints: + # TODO: deal with other types of constraint + if isinstance(cons, (ForeignKeyConstraint, + UniqueConstraint)): + for col_name in cons.columns: + if not isinstance(col_name, basestring): + col_name = col_name.name + if self.name == col_name: + to_drop.add(cons) + table.constraints = table.constraints - to_drop + + if table.c.contains_column(self): + table.c.remove(self) + + # TODO: this is fixed in 0.6 + def copy_fixed(self, **kw): + """Create a copy of this ``Column``, with all attributes.""" + return sqlalchemy.Column(self.name, self.type, self.default, + key=self.key, + primary_key=self.primary_key, + nullable=self.nullable, + quote=self.quote, + index=self.index, + unique=self.unique, + onupdate=self.onupdate, + autoincrement=self.autoincrement, + server_default=self.server_default, + server_onupdate=self.server_onupdate, + *[c.copy(**kw) for c in 
self.constraints]) + + def _check_sanity_constraints(self, name): + """Check if constraint names are correct""" + obj = getattr(self, name) + if (getattr(self, name[:-5]) and not obj): + raise InvalidConstraintError("Column.create() accepts index_name," + " primary_key_name and unique_name to generate constraints") + if not isinstance(obj, basestring) and obj is not None: + raise InvalidConstraintError( + "%s argument for column must be constraint name" % name) + + +class ChangesetIndex(object): + """Changeset extensions to SQLAlchemy Indexes.""" + + __visit_name__ = 'index' + + def rename(self, name, connection=None, **kwargs): + """Change the name of an index. + + :param name: New name of the Index. + :type name: string + :param alter_metadata: If True, Index object will be altered. + :type alter_metadata: bool + :param connection: reuse connection instead of creating a new one. + :type connection: :class:`sqlalchemy.engine.base.Connection` instance + """ + self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) + engine = self.table.bind + self.new_name = name + visitorcallable = get_engine_visitor(engine, 'schemachanger') + engine._run_visitor(visitorcallable, self, connection, **kwargs) + if self.alter_metadata: + self.name = name + + +class ChangesetDefaultClause(object): + """Implements comparison between :class:`DefaultClause` instances""" + + def __eq__(self, other): + if isinstance(other, self.__class__): + if self.arg == other.arg: + return True + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/rhodecode/lib/dbmigrate/migrate/exceptions.py b/rhodecode/lib/dbmigrate/migrate/exceptions.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/exceptions.py @@ -0,0 +1,87 @@ +""" + Provide exception classes for :mod:`migrate` +""" + + +class Error(Exception): + """Error base class.""" + + +class ApiError(Error): + """Base class for API errors.""" + + +class KnownError(ApiError): + """A known error
condition.""" + + +class UsageError(ApiError): + """A known error condition where help should be displayed.""" + + +class ControlledSchemaError(Error): + """Base class for controlled schema errors.""" + + +class InvalidVersionError(ControlledSchemaError): + """Invalid version number.""" + + +class DatabaseNotControlledError(ControlledSchemaError): + """Database should be under version control, but it's not.""" + + +class DatabaseAlreadyControlledError(ControlledSchemaError): + """Database shouldn't be under version control, but it is""" + + +class WrongRepositoryError(ControlledSchemaError): + """This database is under version control by another repository.""" + + +class NoSuchTableError(ControlledSchemaError): + """The table does not exist.""" + + +class PathError(Error): + """Base class for path errors.""" + + +class PathNotFoundError(PathError): + """A path with a file was required; found no file.""" + + +class PathFoundError(PathError): + """A path with no file was required; found a file.""" + + +class RepositoryError(Error): + """Base class for repository errors.""" + + +class InvalidRepositoryError(RepositoryError): + """Invalid repository error.""" + + +class ScriptError(Error): + """Base class for script errors.""" + + +class InvalidScriptError(ScriptError): + """Invalid script error.""" + + +class InvalidVersionError(Error): + """Invalid version error.""" + +# migrate.changeset + +class NotSupportedError(Error): + """Not supported error""" + + +class InvalidConstraintError(Error): + """Invalid constraint error""" + +class MigrateDeprecationWarning(DeprecationWarning): + """Warning for deprecated features in Migrate""" diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/__init__.py @@ -0,0 +1,5 @@ +""" + This package provides functionality to create and manage + repositories of database schema
changesets and to apply these + changesets to databases. +""" diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/api.py b/rhodecode/lib/dbmigrate/migrate/versioning/api.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/api.py @@ -0,0 +1,383 @@ +""" + This module provides an external API to the versioning system. + + .. versionchanged:: 0.6.0 + :func:`migrate.versioning.api.test` and schema diff functions + changed order of positional arguments so all accept `url` and `repository` + as first arguments. + + .. versionchanged:: 0.5.4 + ``--preview_sql`` displays source file when using SQL scripts. + If Python script is used, it runs the action with mocked engine and + returns captured SQL statements. + + .. versionchanged:: 0.5.4 + Deprecated ``--echo`` parameter in favour of new + :func:`migrate.versioning.util.construct_engine` behavior. +""" + +# Dear migrate developers, +# +# please do not comment this module using sphinx syntax because its +# docstrings are presented as user help and most users cannot +# interpret sphinx annotated ReStructuredText. 
+# +# Thanks, +# Jan Dittberner + +import sys +import inspect +import logging + +from migrate import exceptions +from migrate.versioning import (repository, schema, version, + script as script_) # command name conflict +from migrate.versioning.util import catch_known_errors, with_engine + + +log = logging.getLogger(__name__) +command_desc = { + 'help': 'displays help on a given command', + 'create': 'create an empty repository at the specified path', + 'script': 'create an empty change Python script', + 'script_sql': 'create empty change SQL scripts for given database', + 'version': 'display the latest version available in a repository', + 'db_version': 'show the current version of the repository under version control', + 'source': 'display the Python code for a particular version in this repository', + 'version_control': 'mark a database as under this repository\'s version control', + 'upgrade': 'upgrade a database to a later version', + 'downgrade': 'downgrade a database to an earlier version', + 'drop_version_control': 'removes version control from a database', + 'manage': 'creates a Python script that runs Migrate with a set of default values', + 'test': 'performs the upgrade and downgrade command on the given database', + 'compare_model_to_db': 'compare MetaData against the current database state', + 'create_model': 'dump the current database as a Python model to stdout', + 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData', + 'update_db_from_model': 'modify the database to match the structure of the current MetaData', +} +__all__ = command_desc.keys() + +Repository = repository.Repository +ControlledSchema = schema.ControlledSchema +VerNum = version.VerNum +PythonScript = script_.PythonScript +SqlScript = script_.SqlScript + + +# deprecated +def help(cmd=None, **opts): + """%prog help COMMAND + + Displays help on a given command. 
+ """ + if cmd is None: + raise exceptions.UsageError(None) + try: + func = globals()[cmd] + except: + raise exceptions.UsageError( + "'%s' isn't a valid command. Try 'help COMMAND'" % cmd) + ret = func.__doc__ + if sys.argv[0]: + ret = ret.replace('%prog', sys.argv[0]) + return ret + +@catch_known_errors +def create(repository, name, **opts): + """%prog create REPOSITORY_PATH NAME [--table=TABLE] + + Create an empty repository at the specified path. + + You can specify the version_table to be used; by default, it is + 'migrate_version'. This table is created in all version-controlled + databases. + """ + repo_path = Repository.create(repository, name, **opts) + + +@catch_known_errors +def script(description, repository, **opts): + """%prog script DESCRIPTION REPOSITORY_PATH + + Create an empty change script using the next unused version number + appended with the given description. + + For instance, manage.py script "Add initial tables" creates: + repository/versions/001_Add_initial_tables.py + """ + repo = Repository(repository) + repo.create_script(description, **opts) + + +@catch_known_errors +def script_sql(database, repository, **opts): + """%prog script_sql DATABASE REPOSITORY_PATH + + Create empty change SQL scripts for given DATABASE, where DATABASE + is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.) + or generic ('default'). + + For instance, manage.py script_sql postgres creates: + repository/versions/001_postgres_upgrade.sql and + repository/versions/001_postgres_downgrade.sql + """ + repo = Repository(repository) + repo.create_script_sql(database, **opts) + + +def version(repository, **opts): + """%prog version REPOSITORY_PATH + + Display the latest version available in a repository. 
+ """ + repo = Repository(repository) + return repo.latest + + +@with_engine +def db_version(url, repository, **opts): + """%prog db_version URL REPOSITORY_PATH + + Show the current version of the repository with the given + connection string, under version control of the specified + repository. + + The url should be any valid SQLAlchemy connection string. + """ + engine = opts.pop('engine') + schema = ControlledSchema(engine, repository) + return schema.version + + +def source(version, dest=None, repository=None, **opts): + """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH + + Display the Python code for a particular version in this + repository. Save it to the file at DESTINATION or, if omitted, + send to stdout. + """ + if repository is None: + raise exceptions.UsageError("A repository must be specified") + repo = Repository(repository) + ret = repo.version(version).script().source() + if dest is not None: + dest = open(dest, 'w') + dest.write(ret) + dest.close() + ret = None + return ret + + +def upgrade(url, repository, version=None, **opts): + """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql] + + Upgrade a database to a later version. + + This runs the upgrade() function defined in your change scripts. + + By default, the database is updated to the latest available + version. You may specify a version instead, if you wish. + + You may preview the Python or SQL code to be executed, rather than + actually executing it, using the appropriate 'preview' option. + """ + err = "Cannot upgrade a database of version %s to version %s. "\ + "Try 'downgrade' instead." + return _migrate(url, repository, version, upgrade=True, err=err, **opts) + + +def downgrade(url, repository, version, **opts): + """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql] + + Downgrade a database to an earlier version. + + This is the reverse of upgrade; this runs the downgrade() function + defined in your change scripts. 
+ + You may preview the Python or SQL code to be executed, rather than + actually executing it, using the appropriate 'preview' option. + """ + err = "Cannot downgrade a database of version %s to version %s. "\ + "Try 'upgrade' instead." + return _migrate(url, repository, version, upgrade=False, err=err, **opts) + +@with_engine +def test(url, repository, **opts): + """%prog test URL REPOSITORY_PATH [VERSION] + + Performs the upgrade and downgrade option on the given + database. This is not a real test and may leave the database in a + bad state. You should therefore better run the test on a copy of + your database. + """ + engine = opts.pop('engine') + repos = Repository(repository) + script = repos.version(None).script() + + # Upgrade + log.info("Upgrading...") + script.run(engine, 1) + log.info("done") + + log.info("Downgrading...") + script.run(engine, -1) + log.info("done") + log.info("Success") + + +@with_engine +def version_control(url, repository, version=None, **opts): + """%prog version_control URL REPOSITORY_PATH [VERSION] + + Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + This creates the table version_table in the database. + + The url should be any valid SQLAlchemy connection string. + + By default, the database begins at version 0 and is assumed to be + empty. If the database is not empty, you may specify a version at + which to begin instead. No attempt is made to verify this + version's correctness - the database schema is expected to be + identical to what it would be if the database were created from + scratch. + """ + engine = opts.pop('engine') + ControlledSchema.create(engine, repository, version) + + +@with_engine +def drop_version_control(url, repository, **opts): + """%prog drop_version_control URL REPOSITORY_PATH + + Removes version control from a database. 
+ """ + engine = opts.pop('engine') + schema = ControlledSchema(engine, repository) + schema.drop() + + +def manage(file, **opts): + """%prog manage FILENAME [VARIABLES...] + + Creates a script that runs Migrate with a set of default values. + + For example:: + + %prog manage manage.py --repository=/path/to/repository \ +--url=sqlite:///project.db + + would create the script manage.py. The following two commands + would then have exactly the same results:: + + python manage.py version + %prog version --repository=/path/to/repository + """ + Repository.create_manage_file(file, **opts) + + +@with_engine +def compare_model_to_db(url, repository, model, **opts): + """%prog compare_model_to_db URL REPOSITORY_PATH MODEL + + Compare the current model (assumed to be a module level variable + of type sqlalchemy.MetaData) against the current database. + + NOTE: This is EXPERIMENTAL. + """ # TODO: get rid of EXPERIMENTAL label + engine = opts.pop('engine') + return ControlledSchema.compare_model_to_db(engine, model, repository) + + +@with_engine +def create_model(url, repository, **opts): + """%prog create_model URL REPOSITORY_PATH [DECLARATIVE=True] + + Dump the current database as a Python model to stdout. + + NOTE: This is EXPERIMENTAL. + """ # TODO: get rid of EXPERIMENTAL label + engine = opts.pop('engine') + declarative = opts.get('declarative', False) + return ControlledSchema.create_model(engine, repository, declarative) + + +@catch_known_errors +@with_engine +def make_update_script_for_model(url, repository, oldmodel, model, **opts): + """%prog make_update_script_for_model URL REPOSITORY_PATH OLDMODEL MODEL + + Create a script changing the old Python model to the new (current) + Python model, sending to stdout. + + NOTE: This is EXPERIMENTAL. 
+ """ # TODO: get rid of EXPERIMENTAL label + engine = opts.pop('engine') + return PythonScript.make_update_script_for_model( + engine, oldmodel, model, repository, **opts) + + +@with_engine +def update_db_from_model(url, repository, model, **opts): + """%prog update_db_from_model URL REPOSITORY_PATH MODEL + + Modify the database to match the structure of the current Python + model. This also sets the db_version number to the latest in the + repository. + + NOTE: This is EXPERIMENTAL. + """ # TODO: get rid of EXPERIMENTAL label + engine = opts.pop('engine') + schema = ControlledSchema(engine, repository) + schema.update_db_from_model(model) + +@with_engine +def _migrate(url, repository, version, upgrade, err, **opts): + engine = opts.pop('engine') + url = str(engine.url) + schema = ControlledSchema(engine, repository) + version = _migrate_version(schema, version, upgrade, err) + + changeset = schema.changeset(version) + for ver, change in changeset: + nextver = ver + changeset.step + log.info('%s -> %s... 
', ver, nextver) + + if opts.get('preview_sql'): + if isinstance(change, PythonScript): + log.info(change.preview_sql(url, changeset.step, **opts)) + elif isinstance(change, SqlScript): + log.info(change.source()) + + elif opts.get('preview_py'): + if not isinstance(change, PythonScript): + raise exceptions.UsageError("Python source can be only displayed" + " for python migration files") + source_ver = max(ver, nextver) + module = schema.repository.version(source_ver).script().module + funcname = upgrade and "upgrade" or "downgrade" + func = getattr(module, funcname) + log.info(inspect.getsource(func)) + else: + schema.runchange(ver, change, changeset.step) + log.info('done') + + +def _migrate_version(schema, version, upgrade, err): + if version is None: + return version + # Version is specified: ensure we're upgrading in the right direction + # (current version < target version for upgrading; reverse for down) + version = VerNum(version) + cur = schema.version + if upgrade is not None: + if upgrade: + direction = cur <= version + else: + direction = cur >= version + if not direction: + raise exceptions.KnownError(err % (cur, version)) + return version diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/cfgparse.py b/rhodecode/lib/dbmigrate/migrate/versioning/cfgparse.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/cfgparse.py @@ -0,0 +1,27 @@ +""" + Configuration parser module. 
+""" + +from ConfigParser import ConfigParser + +from migrate.versioning.config import * +from migrate.versioning import pathed + + +class Parser(ConfigParser): + """A project configuration file.""" + + def to_dict(self, sections=None): + """It's easier to access config values like dictionaries""" + return self._sections + + +class Config(pathed.Pathed, Parser): + """Configuration class.""" + + def __init__(self, path, *p, **k): + """Confirm the config file exists; read it.""" + self.require_found(path) + pathed.Pathed.__init__(self, path) + Parser.__init__(self, *p, **k) + self.read(path) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/config.py b/rhodecode/lib/dbmigrate/migrate/versioning/config.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/config.py @@ -0,0 +1,14 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +from sqlalchemy.util import OrderedDict + + +__all__ = ['databases', 'operations'] + +databases = ('sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird') + +# Map operation names to function names +operations = OrderedDict() +operations['upgrade'] = 'upgrade' +operations['downgrade'] = 'downgrade' diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py b/rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py @@ -0,0 +1,254 @@ +""" + Code to generate a Python model from a database or differences + between a model and database. 
+ + Some of this is borrowed heavily from the AutoCode project at: + http://code.google.com/p/sqlautocode/ +""" + +import sys +import logging + +import sqlalchemy + +import migrate +import migrate.changeset + + +log = logging.getLogger(__name__) +HEADER = """ +## File autogenerated by genmodel.py + +from sqlalchemy import * +meta = MetaData() +""" + +DECLARATIVE_HEADER = """ +## File autogenerated by genmodel.py + +from sqlalchemy import * +from sqlalchemy.ext import declarative + +Base = declarative.declarative_base() +""" + + +class ModelGenerator(object): + + def __init__(self, diff, engine, declarative=False): + self.diff = diff + self.engine = engine + self.declarative = declarative + + def column_repr(self, col): + kwarg = [] + if col.key != col.name: + kwarg.append('key') + if col.primary_key: + col.primary_key = True # otherwise it dumps it as 1 + kwarg.append('primary_key') + if not col.nullable: + kwarg.append('nullable') + if col.onupdate: + kwarg.append('onupdate') + if col.default: + if col.primary_key: + # I found that PostgreSQL automatically creates a + # default value for the sequence, but let's not show + # that. 
+ pass + else: + kwarg.append('default') + ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg) + + # crs: not sure if this is good idea, but it gets rid of extra + # u'' + name = col.name.encode('utf8') + + type_ = col.type + for cls in col.type.__class__.__mro__: + if cls.__module__ == 'sqlalchemy.types' and \ + not cls.__name__.isupper(): + if cls is not type_.__class__: + type_ = cls() + break + + data = { + 'name': name, + 'type': type_, + 'constraints': ', '.join([repr(cn) for cn in col.constraints]), + 'args': ks and ks or ''} + + if data['constraints']: + if data['args']: + data['args'] = ',' + data['args'] + + if data['constraints'] or data['args']: + data['maybeComma'] = ',' + else: + data['maybeComma'] = '' + + commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data + commonStuff = commonStuff.strip() + data['commonStuff'] = commonStuff + if self.declarative: + return """%(name)s = Column(%(type)r%(commonStuff)s""" % data + else: + return """Column(%(name)r, %(type)r%(commonStuff)s""" % data + + def getTableDefn(self, table): + out = [] + tableName = table.name + if self.declarative: + out.append("class %(table)s(Base):" % {'table': tableName}) + out.append(" __tablename__ = '%(table)s'" % {'table': tableName}) + for col in table.columns: + out.append(" %s" % self.column_repr(col)) + else: + out.append("%(table)s = Table('%(table)s', meta," % \ + {'table': tableName}) + for col in table.columns: + out.append(" %s," % self.column_repr(col)) + out.append(")") + return out + + def _get_tables(self,missingA=False,missingB=False,modified=False): + to_process = [] + for bool_,names,metadata in ( + (missingA,self.diff.tables_missing_from_A,self.diff.metadataB), + (missingB,self.diff.tables_missing_from_B,self.diff.metadataA), + (modified,self.diff.tables_different,self.diff.metadataA), + ): + if bool_: + for name in names: + yield metadata.tables.get(name) + + def toPython(self): + """Assume database is current and model is empty.""" + out 
= [] + if self.declarative: + out.append(DECLARATIVE_HEADER) + else: + out.append(HEADER) + out.append("") + for table in self._get_tables(missingA=True): + out.extend(self.getTableDefn(table)) + out.append("") + return '\n'.join(out) + + def toUpgradeDowngradePython(self, indent=' '): + ''' Assume model is most current and database is out-of-date. ''' + decls = ['from migrate.changeset import schema', + 'meta = MetaData()'] + for table in self._get_tables( + missingA=True,missingB=True,modified=True + ): + decls.extend(self.getTableDefn(table)) + + upgradeCommands, downgradeCommands = [], [] + for tableName in self.diff.tables_missing_from_A: + upgradeCommands.append("%(table)s.drop()" % {'table': tableName}) + downgradeCommands.append("%(table)s.create()" % \ + {'table': tableName}) + for tableName in self.diff.tables_missing_from_B: + upgradeCommands.append("%(table)s.create()" % {'table': tableName}) + downgradeCommands.append("%(table)s.drop()" % {'table': tableName}) + + for tableName in self.diff.tables_different: + dbTable = self.diff.metadataB.tables[tableName] + missingInDatabase, missingInModel, diffDecl = \ + self.diff.colDiffs[tableName] + for col in missingInDatabase: + upgradeCommands.append('%s.columns[%r].create()' % ( + modelTable, col.name)) + downgradeCommands.append('%s.columns[%r].drop()' % ( + modelTable, col.name)) + for col in missingInModel: + upgradeCommands.append('%s.columns[%r].drop()' % ( + modelTable, col.name)) + downgradeCommands.append('%s.columns[%r].create()' % ( + modelTable, col.name)) + for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl: + upgradeCommands.append( + 'assert False, "Can\'t alter columns: %s:%s=>%s"', + modelTable, modelCol.name, databaseCol.name) + downgradeCommands.append( + 'assert False, "Can\'t alter columns: %s:%s=>%s"', + modelTable, modelCol.name, databaseCol.name) + pre_command = ' meta.bind = migrate_engine' + + return ( + '\n'.join(decls), + '\n'.join([pre_command] + ['%s%s' % (indent, 
line) for line in upgradeCommands]), + '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands])) + + def _db_can_handle_this_change(self,td): + if (td.columns_missing_from_B + and not td.columns_missing_from_A + and not td.columns_different): + # Even sqlite can handle this. + return True + else: + return not self.engine.url.drivername.startswith('sqlite') + + def applyModel(self): + """Apply model to current database.""" + + meta = sqlalchemy.MetaData(self.engine) + + for table in self._get_tables(missingA=True): + table = table.tometadata(meta) + table.drop() + for table in self._get_tables(missingB=True): + table = table.tometadata(meta) + table.create() + for modelTable in self._get_tables(modified=True): + tableName = modelTable.name + modelTable = modelTable.tometadata(meta) + dbTable = self.diff.metadataB.tables[tableName] + + td = self.diff.tables_different[tableName] + + if self._db_can_handle_this_change(td): + + for col in td.columns_missing_from_B: + modelTable.columns[col].create() + for col in td.columns_missing_from_A: + dbTable.columns[col].drop() + # XXX handle column changes here. + else: + # Sqlite doesn't support drop column, so you have to + # do more: create temp table, copy data to it, drop + # old table, create new table, copy data back. + # + # I wonder if this is guaranteed to be unique? + tempName = '_temp_%s' % modelTable.name + + def getCopyStatement(): + preparer = self.engine.dialect.preparer + commonCols = [] + for modelCol in modelTable.columns: + if modelCol.name in dbTable.columns: + commonCols.append(modelCol.name) + commonColsStr = ', '.join(commonCols) + return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \ + (tableName, commonColsStr, commonColsStr, tempName) + + # Move the data in one transaction, so that we don't + # leave the database in a nasty state. 
+ connection = self.engine.connect() + trans = connection.begin() + try: + connection.execute( + 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \ + (tempName, modelTable.name)) + # make sure the drop takes place inside our + # transaction with the bind parameter + modelTable.drop(bind=connection) + modelTable.create(bind=connection) + connection.execute(getCopyStatement()) + connection.execute('DROP TABLE %s' % tempName) + trans.commit() + except: + trans.rollback() + raise diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/migrate_repository.py b/rhodecode/lib/dbmigrate/migrate/versioning/migrate_repository.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/migrate_repository.py @@ -0,0 +1,100 @@ +""" + Script to migrate repository from sqlalchemy <= 0.4.4 to the new + repository schema. This shouldn't use any other migrate modules, so + that it can work in any version. +""" + +import os +import sys +import logging + +log = logging.getLogger(__name__) + + +def usage(): + """Gives usage information.""" + print """Usage: %(prog)s repository-to-migrate + + Upgrade your repository to the new flat format. + + NOTE: You should probably make a backup before running this. 
+ """ % {'prog': sys.argv[0]} + + sys.exit(1) + + +def delete_file(filepath): + """Deletes a file and prints a message.""" + log.info('Deleting file: %s' % filepath) + os.remove(filepath) + + +def move_file(src, tgt): + """Moves a file and prints a message.""" + log.info('Moving file %s to %s' % (src, tgt)) + if os.path.exists(tgt): + raise Exception( + 'Cannot move file %s because target %s already exists' % \ + (src, tgt)) + os.rename(src, tgt) + + +def delete_directory(dirpath): + """Delete a directory and print a message.""" + log.info('Deleting directory: %s' % dirpath) + os.rmdir(dirpath) + + +def migrate_repository(repos): + """Does the actual migration to the new repository format.""" + log.info('Migrating repository at: %s to new format' % repos) + versions = '%s/versions' % repos + dirs = os.listdir(versions) + # Only use int's in list. + numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()] + numdirs.sort() # Sort list. + for dirname in numdirs: + origdir = '%s/%s' % (versions, dirname) + log.info('Working on directory: %s' % origdir) + files = os.listdir(origdir) + files.sort() + for filename in files: + # Delete compiled Python files. + if filename.endswith('.pyc') or filename.endswith('.pyo'): + delete_file('%s/%s' % (origdir, filename)) + + # Delete empty __init__.py files. + origfile = '%s/__init__.py' % origdir + if os.path.exists(origfile) and len(open(origfile).read()) == 0: + delete_file(origfile) + + # Move sql upgrade scripts. + if filename.endswith('.sql'): + version, dbms, operation = filename.split('.', 3)[0:3] + origfile = '%s/%s' % (origdir, filename) + # For instance: 2.postgres.upgrade.sql -> + # 002_postgres_upgrade.sql + tgtfile = '%s/%03d_%s_%s.sql' % ( + versions, int(version), dbms, operation) + move_file(origfile, tgtfile) + + # Move Python upgrade script. 
+ pyfile = '%s.py' % dirname + pyfilepath = '%s/%s' % (origdir, pyfile) + if os.path.exists(pyfilepath): + tgtfile = '%s/%03d.py' % (versions, int(dirname)) + move_file(pyfilepath, tgtfile) + + # Try to remove directory. Will fail if it's not empty. + delete_directory(origdir) + + +def main(): + """Main function to be called when using this script.""" + if len(sys.argv) != 2: + usage() + migrate_repository(sys.argv[1]) + + +if __name__ == '__main__': + main() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/pathed.py b/rhodecode/lib/dbmigrate/migrate/versioning/pathed.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/pathed.py @@ -0,0 +1,75 @@ +""" + A path/directory class. +""" + +import os +import shutil +import logging + +from migrate import exceptions +from migrate.versioning.config import * +from migrate.versioning.util import KeyedInstance + + +log = logging.getLogger(__name__) + +class Pathed(KeyedInstance): + """ + A class associated with a path/directory tree. + + Only one instance of this class may exist for a particular file; + __new__ will return an existing instance if possible + """ + parent = None + + @classmethod + def _key(cls, path): + return str(path) + + def __init__(self, path): + self.path = path + if self.__class__.parent is not None: + self._init_parent(path) + + def _init_parent(self, path): + """Try to initialize this object's parent, if it has one""" + parent_path = self.__class__._parent_path(path) + self.parent = self.__class__.parent(parent_path) + log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path)) + self.parent._init_child(path, self) + + def _init_child(self, child, path): + """Run when a child of this object is initialized. + + Parameters: the child object; the path to this object (its + parent) + """ + + @classmethod + def _parent_path(cls, path): + """ + Fetch the path of this object's parent from this object's path. 
+ """ + # os.path.dirname(), but strip directories like files (like + # unix basename) + # + # Treat directories like files... + if path[-1] == '/': + path = path[:-1] + ret = os.path.dirname(path) + return ret + + @classmethod + def require_notfound(cls, path): + """Ensures a given path does not already exist""" + if os.path.exists(path): + raise exceptions.PathFoundError(path) + + @classmethod + def require_found(cls, path): + """Ensures a given path already exists""" + if not os.path.exists(path): + raise exceptions.PathNotFoundError(path) + + def __str__(self): + return self.path diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/repository.py b/rhodecode/lib/dbmigrate/migrate/versioning/repository.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/repository.py @@ -0,0 +1,231 @@ +""" + SQLAlchemy migrate repository management. +""" +import os +import shutil +import string +import logging + +from pkg_resources import resource_filename +from tempita import Template as TempitaTemplate + +from migrate import exceptions +from migrate.versioning import version, pathed, cfgparse +from migrate.versioning.template import Template +from migrate.versioning.config import * + + +log = logging.getLogger(__name__) + +class Changeset(dict): + """A collection of changes to be applied to a database. + + Changesets are bound to a repository and manage a set of + scripts from that repository. + + Behaves like a dict, for the most part. Keys are ordered based on step value. + """ + + def __init__(self, start, *changes, **k): + """ + Give a start version; step must be explicitly stated. + """ + self.step = k.pop('step', 1) + self.start = version.VerNum(start) + self.end = self.start + for change in changes: + self.add(change) + + def __iter__(self): + return iter(self.items()) + + def keys(self): + """ + In a series of upgrades x -> y, keys are version x. Sorted. 
+ """ + ret = super(Changeset, self).keys() + # Reverse order if downgrading + ret.sort(reverse=(self.step < 1)) + return ret + + def values(self): + return [self[k] for k in self.keys()] + + def items(self): + return zip(self.keys(), self.values()) + + def add(self, change): + """Add new change to changeset""" + key = self.end + self.end += self.step + self[key] = change + + def run(self, *p, **k): + """Run the changeset scripts""" + for version, script in self: + script.run(*p, **k) + + +class Repository(pathed.Pathed): + """A project's change script repository""" + + _config = 'migrate.cfg' + _versions = 'versions' + + def __init__(self, path): + log.debug('Loading repository %s...' % path) + self.verify(path) + super(Repository, self).__init__(path) + self.config = cfgparse.Config(os.path.join(self.path, self._config)) + self.versions = version.Collection(os.path.join(self.path, + self._versions)) + log.debug('Repository %s loaded successfully' % path) + log.debug('Config: %r' % self.config.to_dict()) + + @classmethod + def verify(cls, path): + """ + Ensure the target path is a valid repository. + + :raises: :exc:`InvalidRepositoryError ` + """ + # Ensure the existence of required files + try: + cls.require_found(path) + cls.require_found(os.path.join(path, cls._config)) + cls.require_found(os.path.join(path, cls._versions)) + except exceptions.PathNotFoundError, e: + raise exceptions.InvalidRepositoryError(path) + + @classmethod + def prepare_config(cls, tmpl_dir, name, options=None): + """ + Prepare a project configuration file for a new project. 
+ + :param tmpl_dir: Path to Repository template + :param config_file: Name of the config file in Repository template + :param name: Repository name + :type tmpl_dir: string + :type config_file: string + :type name: string + :returns: Populated config file + """ + if options is None: + options = {} + options.setdefault('version_table', 'migrate_version') + options.setdefault('repository_id', name) + options.setdefault('required_dbs', []) + + tmpl = open(os.path.join(tmpl_dir, cls._config)).read() + ret = TempitaTemplate(tmpl).substitute(options) + + # cleanup + del options['__template_name__'] + + return ret + + @classmethod + def create(cls, path, name, **opts): + """Create a repository at a specified path""" + cls.require_notfound(path) + theme = opts.pop('templates_theme', None) + t_path = opts.pop('templates_path', None) + + # Create repository + tmpl_dir = Template(t_path).get_repository(theme=theme) + shutil.copytree(tmpl_dir, path) + + # Edit config defaults + config_text = cls.prepare_config(tmpl_dir, name, options=opts) + fd = open(os.path.join(path, cls._config), 'w') + fd.write(config_text) + fd.close() + + opts['repository_name'] = name + + # Create a management script + manager = os.path.join(path, 'manage.py') + Repository.create_manage_file(manager, templates_theme=theme, + templates_path=t_path, **opts) + + return cls(path) + + def create_script(self, description, **k): + """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`""" + self.versions.create_new_python_version(description, **k) + + def create_script_sql(self, database, **k): + """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`""" + self.versions.create_new_sql_version(database, **k) + + @property + def latest(self): + """API to :attr:`migrate.versioning.version.Collection.latest`""" + return self.versions.latest + + @property + def version_table(self): + """Returns version_table name specified in config""" + return 
self.config.get('db_settings', 'version_table') + + @property + def id(self): + """Returns repository id specified in config""" + return self.config.get('db_settings', 'repository_id') + + def version(self, *p, **k): + """API to :attr:`migrate.versioning.version.Collection.version`""" + return self.versions.version(*p, **k) + + @classmethod + def clear(cls): + # TODO: deletes repo + super(Repository, cls).clear() + version.Collection.clear() + + def changeset(self, database, start, end=None): + """Create a changeset to migrate this database from ver. start to end/latest. + + :param database: name of database to generate changeset + :param start: version to start at + :param end: version to end at (latest if None given) + :type database: string + :type start: int + :type end: int + :returns: :class:`Changeset instance ` + """ + start = version.VerNum(start) + + if end is None: + end = self.latest + else: + end = version.VerNum(end) + + if start <= end: + step = 1 + range_mod = 1 + op = 'upgrade' + else: + step = -1 + range_mod = 0 + op = 'downgrade' + + versions = range(start + range_mod, end + range_mod, step) + changes = [self.version(v).script(database, op) for v in versions] + ret = Changeset(start, step=step, *changes) + return ret + + @classmethod + def create_manage_file(cls, file_, **opts): + """Create a project management script (manage.py) + + :param file_: Destination file to be written + :param opts: Options that are passed to :func:`migrate.versioning.shell.main` + """ + mng_file = Template(opts.pop('templates_path', None))\ + .get_manage(theme=opts.pop('templates_theme', None)) + + tmpl = open(mng_file).read() + fd = open(file_, 'w') + fd.write(TempitaTemplate(tmpl).substitute(opts)) + fd.close() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/schema.py b/rhodecode/lib/dbmigrate/migrate/versioning/schema.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/schema.py @@ -0,0 +1,213 @@ +""" + Database schema 
version management. +""" +import sys +import logging + +from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, + create_engine) +from sqlalchemy.sql import and_ +from sqlalchemy import exceptions as sa_exceptions +from sqlalchemy.sql import bindparam + +from migrate import exceptions +from migrate.versioning import genmodel, schemadiff +from migrate.versioning.repository import Repository +from migrate.versioning.util import load_model +from migrate.versioning.version import VerNum + + +log = logging.getLogger(__name__) + +class ControlledSchema(object): + """A database under version control""" + + def __init__(self, engine, repository): + if isinstance(repository, basestring): + repository = Repository(repository) + self.engine = engine + self.repository = repository + self.meta = MetaData(engine) + self.load() + + def __eq__(self, other): + """Compare two schemas by repositories and versions""" + return (self.repository is other.repository \ + and self.version == other.version) + + def load(self): + """Load controlled schema version info from DB""" + tname = self.repository.version_table + try: + if not hasattr(self, 'table') or self.table is None: + self.table = Table(tname, self.meta, autoload=True) + + result = self.engine.execute(self.table.select( + self.table.c.repository_id == str(self.repository.id))) + + data = list(result)[0] + except: + cls, exc, tb = sys.exc_info() + raise exceptions.DatabaseNotControlledError, exc.__str__(), tb + + self.version = data['version'] + return data + + def drop(self): + """ + Remove version control from a database. + """ + try: + self.table.drop() + except (sa_exceptions.SQLError): + raise exceptions.DatabaseNotControlledError(str(self.table)) + + def changeset(self, version=None): + """API to Changeset creation. + + Uses self.version for start version and engine.name + to get database name. 
+ """ + database = self.engine.name + start_ver = self.version + changeset = self.repository.changeset(database, start_ver, version) + return changeset + + def runchange(self, ver, change, step): + startver = ver + endver = ver + step + # Current database version must be correct! Don't run if corrupt! + if self.version != startver: + raise exceptions.InvalidVersionError("%s is not %s" % \ + (self.version, startver)) + # Run the change + change.run(self.engine, step) + + # Update/refresh database version + self.update_repository_table(startver, endver) + self.load() + + def update_repository_table(self, startver, endver): + """Update version_table with new information""" + update = self.table.update(and_(self.table.c.version == int(startver), + self.table.c.repository_id == str(self.repository.id))) + self.engine.execute(update, version=int(endver)) + + def upgrade(self, version=None): + """ + Upgrade (or downgrade) to a specified version, or latest version. + """ + changeset = self.changeset(version) + for ver, change in changeset: + self.runchange(ver, change, changeset.step) + + def update_db_from_model(self, model): + """ + Modify the database to match the structure of the current Python model. + """ + model = load_model(model) + + diff = schemadiff.getDiffOfModelAgainstDatabase( + model, self.engine, excludeTables=[self.repository.version_table] + ) + genmodel.ModelGenerator(diff,self.engine).applyModel() + + self.update_repository_table(self.version, int(self.repository.latest)) + + self.load() + + @classmethod + def create(cls, engine, repository, version=None): + """ + Declare a database to be under a repository's version control. 
+ + :raises: :exc:`DatabaseAlreadyControlledError` + :returns: :class:`ControlledSchema` + """ + # Confirm that the version # is valid: positive, integer, + # exists in repos + if isinstance(repository, basestring): + repository = Repository(repository) + version = cls._validate_version(repository, version) + table = cls._create_table_version(engine, repository, version) + # TODO: history table + # Load repository information and return + return cls(engine, repository) + + @classmethod + def _validate_version(cls, repository, version): + """ + Ensures this is a valid version number for this repository. + + :raises: :exc:`InvalidVersionError` if invalid + :return: valid version number + """ + if version is None: + version = 0 + try: + version = VerNum(version) # raises valueerror + if version < 0 or version > repository.latest: + raise ValueError() + except ValueError: + raise exceptions.InvalidVersionError(version) + return version + + @classmethod + def _create_table_version(cls, engine, repository, version): + """ + Creates the versioning table in a database. 
+ + :raises: :exc:`DatabaseAlreadyControlledError` + """ + # Create tables + tname = repository.version_table + meta = MetaData(engine) + + table = Table( + tname, meta, + Column('repository_id', String(250), primary_key=True), + Column('repository_path', Text), + Column('version', Integer), ) + + # there can be multiple repositories/schemas in the same db + if not table.exists(): + table.create() + + # test for existing repository_id + s = table.select(table.c.repository_id == bindparam("repository_id")) + result = engine.execute(s, repository_id=repository.id) + if result.fetchone(): + raise exceptions.DatabaseAlreadyControlledError + + # Insert data + engine.execute(table.insert().values( + repository_id=repository.id, + repository_path=repository.path, + version=int(version))) + return table + + @classmethod + def compare_model_to_db(cls, engine, model, repository): + """ + Compare the current model against the current database. + """ + if isinstance(repository, basestring): + repository = Repository(repository) + model = load_model(model) + + diff = schemadiff.getDiffOfModelAgainstDatabase( + model, engine, excludeTables=[repository.version_table]) + return diff + + @classmethod + def create_model(cls, engine, repository, declarative=False): + """ + Dump the current database as a Python model. + """ + if isinstance(repository, basestring): + repository = Repository(repository) + + diff = schemadiff.getDiffOfModelAgainstDatabase( + MetaData(), engine, excludeTables=[repository.version_table] + ) + return genmodel.ModelGenerator(diff, engine, declarative).toPython() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/schemadiff.py b/rhodecode/lib/dbmigrate/migrate/versioning/schemadiff.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/schemadiff.py @@ -0,0 +1,285 @@ +""" + Schema differencing support. 
+""" + + import logging +import sqlalchemy + +from migrate.changeset import SQLA_06 +from sqlalchemy.types import Float + +log = logging.getLogger(__name__) + +def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None): + """ + Return differences of model against database. + + :return: object which will evaluate to :keyword:`True` if there \ + are differences else :keyword:`False`. + """ + return SchemaDiff(metadata, + sqlalchemy.MetaData(engine, reflect=True), + labelA='model', + labelB='database', + excludeTables=excludeTables) + + +def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None): + """ + Return differences of model against another model. + + :return: object which will evaluate to :keyword:`True` if there \ + are differences else :keyword:`False`. + """ + return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables) + + +class ColDiff(object): + """ + Container for differences in one :class:`~sqlalchemy.schema.Column` + between two :class:`~sqlalchemy.schema.Table` instances, ``A`` + and ``B``. + + .. attribute:: col_A + + The :class:`~sqlalchemy.schema.Column` object for A. + + .. attribute:: col_B + + The :class:`~sqlalchemy.schema.Column` object for B. + + .. attribute:: type_A + + The most generic type of the :class:`~sqlalchemy.schema.Column` + object in A. + + .. attribute:: type_B + + The most generic type of the :class:`~sqlalchemy.schema.Column` + object in B. 
+ + """ + + diff = False + + def __init__(self,col_A,col_B): + self.col_A = col_A + self.col_B = col_B + + self.type_A = col_A.type + self.type_B = col_B.type + + self.affinity_A = self.type_A._type_affinity + self.affinity_B = self.type_B._type_affinity + + if self.affinity_A is not self.affinity_B: + self.diff = True + return + + if isinstance(self.type_A,Float) or isinstance(self.type_B,Float): + if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)): + self.diff=True + return + + for attr in ('precision','scale','length'): + A = getattr(self.type_A,attr,None) + B = getattr(self.type_B,attr,None) + if not (A is None or B is None) and A!=B: + self.diff=True + return + + def __nonzero__(self): + return self.diff + +class TableDiff(object): + """ + Container for differences in one :class:`~sqlalchemy.schema.Table` + between two :class:`~sqlalchemy.schema.MetaData` instances, ``A`` + and ``B``. + + .. attribute:: columns_missing_from_A + + A sequence of column names that were found in B but weren't in + A. + + .. attribute:: columns_missing_from_B + + A sequence of column names that were found in A but weren't in + B. + + .. attribute:: columns_different + + A dictionary containing information about columns that were + found to be different. + It maps column names to a :class:`ColDiff` objects describing the + differences found. + """ + __slots__ = ( + 'columns_missing_from_A', + 'columns_missing_from_B', + 'columns_different', + ) + + def __nonzero__(self): + return bool( + self.columns_missing_from_A or + self.columns_missing_from_B or + self.columns_different + ) + +class SchemaDiff(object): + """ + Compute the difference between two :class:`~sqlalchemy.schema.MetaData` + objects. + + The string representation of a :class:`SchemaDiff` will summarise + the changes found between the two + :class:`~sqlalchemy.schema.MetaData` objects. 
+ + The length of a :class:`SchemaDiff` will give the number of + changes found, enabling it to be used much like a boolean in + expressions. + + :param metadataA: + First :class:`~sqlalchemy.schema.MetaData` to compare. + + :param metadataB: + Second :class:`~sqlalchemy.schema.MetaData` to compare. + + :param labelA: + The label to use in messages about the first + :class:`~sqlalchemy.schema.MetaData`. + + :param labelB: + The label to use in messages about the second + :class:`~sqlalchemy.schema.MetaData`. + + :param excludeTables: + A sequence of table names to exclude. + + .. attribute:: tables_missing_from_A + + A sequence of table names that were found in B but weren't in + A. + + .. attribute:: tables_missing_from_B + + A sequence of table names that were found in A but weren't in + B. + + .. attribute:: tables_different + + A dictionary containing information about tables that were found + to be different. + It maps table names to a :class:`TableDiff` objects describing the + differences found. 
+ """ + + def __init__(self, + metadataA, metadataB, + labelA='metadataA', + labelB='metadataB', + excludeTables=None): + + self.metadataA, self.metadataB = metadataA, metadataB + self.labelA, self.labelB = labelA, labelB + self.label_width = max(len(labelA),len(labelB)) + excludeTables = set(excludeTables or []) + + A_table_names = set(metadataA.tables.keys()) + B_table_names = set(metadataB.tables.keys()) + + self.tables_missing_from_A = sorted( + B_table_names - A_table_names - excludeTables + ) + self.tables_missing_from_B = sorted( + A_table_names - B_table_names - excludeTables + ) + + self.tables_different = {} + for table_name in A_table_names.intersection(B_table_names): + + td = TableDiff() + + A_table = metadataA.tables[table_name] + B_table = metadataB.tables[table_name] + + A_column_names = set(A_table.columns.keys()) + B_column_names = set(B_table.columns.keys()) + + td.columns_missing_from_A = sorted( + B_column_names - A_column_names + ) + + td.columns_missing_from_B = sorted( + A_column_names - B_column_names + ) + + td.columns_different = {} + + for col_name in A_column_names.intersection(B_column_names): + + cd = ColDiff( + A_table.columns.get(col_name), + B_table.columns.get(col_name) + ) + + if cd: + td.columns_different[col_name]=cd + + # XXX - index and constraint differences should + # be checked for here + + if td: + self.tables_different[table_name]=td + + def __str__(self): + ''' Summarize differences. 
''' + out = [] + column_template =' %%%is: %%r' % self.label_width + + for names,label in ( + (self.tables_missing_from_A,self.labelA), + (self.tables_missing_from_B,self.labelB), + ): + if names: + out.append( + ' tables missing from %s: %s' % ( + label,', '.join(sorted(names)) + ) + ) + + for name,td in sorted(self.tables_different.items()): + out.append( + ' table with differences: %s' % name + ) + for names,label in ( + (td.columns_missing_from_A,self.labelA), + (td.columns_missing_from_B,self.labelB), + ): + if names: + out.append( + ' %s missing these columns: %s' % ( + label,', '.join(sorted(names)) + ) + ) + for name,cd in td.columns_different.items(): + out.append(' column with differences: %s' % name) + out.append(column_template % (self.labelA,cd.col_A)) + out.append(column_template % (self.labelB,cd.col_B)) + + if out: + out.insert(0, 'Schema diffs:') + return '\n'.join(out) + else: + return 'No schema diffs' + + def __len__(self): + """ + Used in bool evaluation, return of 0 means no diffs. 
+ """ + return ( + len(self.tables_missing_from_A) + + len(self.tables_missing_from_B) + + len(self.tables_different) + ) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/script/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/script/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/script/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from migrate.versioning.script.base import BaseScript +from migrate.versioning.script.py import PythonScript +from migrate.versioning.script.sql import SqlScript diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/script/base.py b/rhodecode/lib/dbmigrate/migrate/versioning/script/base.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/script/base.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import logging + +from migrate import exceptions +from migrate.versioning.config import operations +from migrate.versioning import pathed + + +log = logging.getLogger(__name__) + +class BaseScript(pathed.Pathed): + """Base class for other types of scripts. + All scripts have the following properties: + + source (script.source()) + The source code of the script + version (script.version()) + The version number of the script + operations (script.operations()) + The operations defined by the script: upgrade(), downgrade() or both. + Returns a tuple of operations. + Can also check for an operation with ex. script.operation(Script.ops.up) + """ # TODO: sphinxfy this and implement it correctly + + def __init__(self, path): + log.debug('Loading script %s...' 
% path) + self.verify(path) + super(BaseScript, self).__init__(path) + log.debug('Script %s loaded successfully' % path) + + @classmethod + def verify(cls, path): + """Ensure this is a valid script + This version simply ensures the script file's existence + + :raises: :exc:`InvalidScriptError ` + """ + try: + cls.require_found(path) + except: + raise exceptions.InvalidScriptError(path) + + def source(self): + """:returns: source code of the script. + :rtype: string + """ + fd = open(self.path) + ret = fd.read() + fd.close() + return ret + + def run(self, engine): + """Core of each BaseScript subclass. + This method executes the script. + """ + raise NotImplementedError() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/script/py.py b/rhodecode/lib/dbmigrate/migrate/versioning/script/py.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/script/py.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import shutil +import warnings +import logging +from StringIO import StringIO + +import migrate +from migrate.versioning import genmodel, schemadiff +from migrate.versioning.config import operations +from migrate.versioning.template import Template +from migrate.versioning.script import base +from migrate.versioning.util import import_path, load_model, with_engine +from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError + +log = logging.getLogger(__name__) +__all__ = ['PythonScript'] + + +class PythonScript(base.BaseScript): + """Base for Python scripts""" + + @classmethod + def create(cls, path, **opts): + """Create an empty migration script at specified path + + :returns: :class:`PythonScript instance `""" + cls.require_notfound(path) + + src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None)) + shutil.copy(src, path) + + return cls(path) + + @classmethod + def make_update_script_for_model(cls, engine, oldmodel, + model, 
repository, **opts): + """Create a migration script based on difference between two SA models. + + :param repository: path to migrate repository + :param oldmodel: dotted.module.name:SAClass or SAClass object + :param model: dotted.module.name:SAClass or SAClass object + :param engine: SQLAlchemy engine + :type repository: string or :class:`Repository instance ` + :type oldmodel: string or Class + :type model: string or Class + :type engine: Engine instance + :returns: Upgrade / Downgrade script + :rtype: string + """ + + if isinstance(repository, basestring): + # oh dear, an import cycle! + from migrate.versioning.repository import Repository + repository = Repository(repository) + + oldmodel = load_model(oldmodel) + model = load_model(model) + + # Compute differences. + diff = schemadiff.getDiffOfModelAgainstModel( + oldmodel, + model, + excludeTables=[repository.version_table]) + # TODO: diff can be False (there is no difference?) + decls, upgradeCommands, downgradeCommands = \ + genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython() + + # Store differences into file. 
+ src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None)) + f = open(src) + contents = f.read() + f.close() + + # generate source + search = 'def upgrade(migrate_engine):' + contents = contents.replace(search, '\n\n'.join((decls, search)), 1) + if upgradeCommands: + contents = contents.replace(' pass', upgradeCommands, 1) + if downgradeCommands: + contents = contents.replace(' pass', downgradeCommands, 1) + return contents + + @classmethod + def verify_module(cls, path): + """Ensure path is a valid script + + :param path: Script location + :type path: string + :raises: :exc:`InvalidScriptError ` + :returns: Python module + """ + # Try to import and get the upgrade() func + module = import_path(path) + try: + assert callable(module.upgrade) + except Exception, e: + raise InvalidScriptError(path + ': %s' % str(e)) + return module + + def preview_sql(self, url, step, **args): + """Mocks SQLAlchemy Engine to store all executed calls in a string + and runs :meth:`PythonScript.run ` + + :returns: SQL file + """ + buf = StringIO() + args['engine_arg_strategy'] = 'mock' + args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p) + + @with_engine + def go(url, step, **kw): + engine = kw.pop('engine') + self.run(engine, step) + return buf.getvalue() + + return go(url, step, **args) + + def run(self, engine, step): + """Core method of Script file. 
+ Executes :func:`upgrade` or :func:`downgrade` functions + + :param engine: SQLAlchemy Engine + :param step: Operation to run (positive: upgrade, negative: downgrade) + :type engine: Engine instance + :type step: int + """ + if step > 0: + op = 'upgrade' + elif step < 0: + op = 'downgrade' + else: + raise ScriptError("%d is not a valid step" % step) + + funcname = base.operations[op] + script_func = self._func(funcname) + + try: + script_func(engine) + except TypeError: + warnings.warn("upgrade/downgrade functions must accept engine" + " parameter (since version > 0.5.4)", MigrateDeprecationWarning) + raise + + @property + def module(self): + """Calls :meth:`migrate.versioning.script.py.verify_module` + and returns it. + """ + if not hasattr(self, '_module'): + self._module = self.verify_module(self.path) + return self._module + + def _func(self, funcname): + if not hasattr(self.module, funcname): + msg = "Function '%s' is not defined in this script" + raise ScriptError(msg % funcname) + return getattr(self.module, funcname) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/script/sql.py b/rhodecode/lib/dbmigrate/migrate/versioning/script/sql.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/script/sql.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import logging +import shutil + +from migrate.versioning.script import base +from migrate.versioning.template import Template + + +log = logging.getLogger(__name__) + +class SqlScript(base.BaseScript): + """A file containing plain SQL statements.""" + + @classmethod + def create(cls, path, **opts): + """Create an empty migration script at specified path + + :returns: :class:`SqlScript instance `""" + cls.require_notfound(path) + + src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None)) + shutil.copy(src, path) + return cls(path) + + # TODO: why is step parameter even here? 
+ def run(self, engine, step=None, executemany=True): + """Runs SQL script through raw dbapi execute call""" + text = self.source() + # Don't rely on SA's autocommit here + # (SA uses .startswith to check if a commit is needed. What if script + # starts with a comment?) + conn = engine.connect() + try: + trans = conn.begin() + try: + # HACK: SQLite doesn't allow multiple statements through + # its execute() method, but it provides executescript() instead + dbapi = conn.engine.raw_connection() + if executemany and getattr(dbapi, 'executescript', None): + dbapi.executescript(text) + else: + conn.execute(text) + trans.commit() + except: + trans.rollback() + raise + finally: + conn.close() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/shell.py b/rhodecode/lib/dbmigrate/migrate/versioning/shell.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/shell.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""The migrate command-line tool.""" + +import sys +import inspect +import logging +from optparse import OptionParser, BadOptionError + +from migrate import exceptions +from migrate.versioning import api +from migrate.versioning.config import * +from migrate.versioning.util import asbool + + +alias = dict( + s=api.script, + vc=api.version_control, + dbv=api.db_version, + v=api.version, +) + +def alias_setup(): + global alias + for key, val in alias.iteritems(): + setattr(api, key, val) +alias_setup() + + +class PassiveOptionParser(OptionParser): + + def _process_args(self, largs, rargs, values): + """little hack to support all --some_option=value parameters""" + + while rargs: + arg = rargs[0] + if arg == "--": + del rargs[0] + return + elif arg[0:2] == "--": + # if parser does not know about the option + # pass it along (make it anonymous) + try: + opt = arg.split('=', 1)[0] + self._match_long_opt(opt) + except BadOptionError: + largs.append(arg) + del rargs[0] + else: + self._process_long_opt(rargs, 
values) + elif arg[:1] == "-" and len(arg) > 1: + self._process_short_opts(rargs, values) + elif self.allow_interspersed_args: + largs.append(arg) + del rargs[0] + +def main(argv=None, **kwargs): + """Shell interface to :mod:`migrate.versioning.api`. + + kwargs are default options that can be overriden with passing + --some_option as command line option + + :param disable_logging: Let migrate configure logging + :type disable_logging: bool + """ + if argv is not None: + argv = argv + else: + argv = list(sys.argv[1:]) + commands = list(api.__all__) + commands.sort() + + usage = """%%prog COMMAND ... + + Available commands: + %s + + Enter "%%prog help COMMAND" for information on a particular command. + """ % '\n\t'.join(["%s - %s" % (command.ljust(28), + api.command_desc.get(command)) for command in commands]) + + parser = PassiveOptionParser(usage=usage) + parser.add_option("-d", "--debug", + action="store_true", + dest="debug", + default=False, + help="Shortcut to turn on DEBUG mode for logging") + parser.add_option("-q", "--disable_logging", + action="store_true", + dest="disable_logging", + default=False, + help="Use this option to disable logging configuration") + help_commands = ['help', '-h', '--help'] + HELP = False + + try: + command = argv.pop(0) + if command in help_commands: + HELP = True + command = argv.pop(0) + except IndexError: + parser.print_help() + return + + command_func = getattr(api, command, None) + if command_func is None or command.startswith('_'): + parser.error("Invalid command %s" % command) + + parser.set_usage(inspect.getdoc(command_func)) + f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func) + for arg in f_args: + parser.add_option( + "--%s" % arg, + dest=arg, + action='store', + type="string") + + # display help of the current command + if HELP: + parser.print_help() + return + + options, args = parser.parse_args(argv) + + # override kwargs with anonymous parameters + override_kwargs = dict() + for arg in 
list(args): + if arg.startswith('--'): + args.remove(arg) + if '=' in arg: + opt, value = arg[2:].split('=', 1) + else: + opt = arg[2:] + value = True + override_kwargs[opt] = value + + # override kwargs with options if user is overwriting + for key, value in options.__dict__.iteritems(): + if value is not None: + override_kwargs[key] = value + + # arguments that function accepts without passed kwargs + f_required = list(f_args) + candidates = dict(kwargs) + candidates.update(override_kwargs) + for key, value in candidates.iteritems(): + if key in f_args: + f_required.remove(key) + + # map function arguments to parsed arguments + for arg in args: + try: + kw = f_required.pop(0) + except IndexError: + parser.error("Too many arguments for command %s: %s" % (command, + arg)) + kwargs[kw] = arg + + # apply overrides + kwargs.update(override_kwargs) + + # configure options + for key, value in options.__dict__.iteritems(): + kwargs.setdefault(key, value) + + # configure logging + if not asbool(kwargs.pop('disable_logging', False)): + # filter to log =< INFO into stdout and rest to stderr + class SingleLevelFilter(logging.Filter): + def __init__(self, min=None, max=None): + self.min = min or 0 + self.max = max or 100 + + def filter(self, record): + return self.min <= record.levelno <= self.max + + logger = logging.getLogger() + h1 = logging.StreamHandler(sys.stdout) + f1 = SingleLevelFilter(max=logging.INFO) + h1.addFilter(f1) + h2 = logging.StreamHandler(sys.stderr) + f2 = SingleLevelFilter(min=logging.WARN) + h2.addFilter(f2) + logger.addHandler(h1) + logger.addHandler(h2) + + if options.debug: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.INFO) + + log = logging.getLogger(__name__) + + # check if all args are given + try: + num_defaults = len(f_defaults) + except TypeError: + num_defaults = 0 + f_args_default = f_args[len(f_args) - num_defaults:] + required = list(set(f_required) - set(f_args_default)) + if required: + parser.error("Not enough 
arguments for command %s: %s not specified" \ + % (command, ', '.join(required))) + + # handle command + try: + ret = command_func(**kwargs) + if ret is not None: + log.info(ret) + except (exceptions.UsageError, exceptions.KnownError), e: + parser.error(e.args[0]) + +if __name__ == "__main__": + main() diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/template.py b/rhodecode/lib/dbmigrate/migrate/versioning/template.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/template.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import shutil +import sys + +from pkg_resources import resource_filename + +from migrate.versioning.config import * +from migrate.versioning import pathed + + +class Collection(pathed.Pathed): + """A collection of templates of a specific type""" + _mask = None + + def get_path(self, file): + return os.path.join(self.path, str(file)) + + +class RepositoryCollection(Collection): + _mask = '%s' + +class ScriptCollection(Collection): + _mask = '%s.py_tmpl' + +class ManageCollection(Collection): + _mask = '%s.py_tmpl' + +class SQLScriptCollection(Collection): + _mask = '%s.py_tmpl' + +class Template(pathed.Pathed): + """Finds the paths/packages of various Migrate templates. + + :param path: Templates are loaded from migrate package + if `path` is not provided. 
+ """ + pkg = 'migrate.versioning.templates' + _manage = 'manage.py_tmpl' + + def __new__(cls, path=None): + if path is None: + path = cls._find_path(cls.pkg) + return super(Template, cls).__new__(cls, path) + + def __init__(self, path=None): + if path is None: + path = Template._find_path(self.pkg) + super(Template, self).__init__(path) + self.repository = RepositoryCollection(os.path.join(path, 'repository')) + self.script = ScriptCollection(os.path.join(path, 'script')) + self.manage = ManageCollection(os.path.join(path, 'manage')) + self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script')) + + @classmethod + def _find_path(cls, pkg): + """Returns absolute path to dotted python package.""" + tmp_pkg = pkg.rsplit('.', 1) + + if len(tmp_pkg) != 1: + return resource_filename(tmp_pkg[0], tmp_pkg[1]) + else: + return resource_filename(tmp_pkg[0], '') + + def _get_item(self, collection, theme=None): + """Locates an item in a collection and returns its path. + + :param collection: name of collection to locate + :param theme: name of the theme within the collection (defaults to "default") + :returns: path to the themed item inside the collection + :rtype: str + """ + item = getattr(self, collection) + theme_mask = getattr(item, '_mask') + theme = theme_mask % (theme or 'default') + return item.get_path(theme) + + def get_repository(self, *a, **kw): + """Calls self._get_item('repository', *a, **kw)""" + return self._get_item('repository', *a, **kw) + + def get_script(self, *a, **kw): + """Calls self._get_item('script', *a, **kw)""" + return self._get_item('script', *a, **kw) + + def get_sql_script(self, *a, **kw): + """Calls self._get_item('sql_script', *a, **kw)""" + return self._get_item('sql_script', *a, **kw) + + def get_manage(self, *a, **kw): + """Calls self._get_item('manage', *a, **kw)""" + return self._get_item('manage', *a, **kw) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/__init__.py new file mode 100644 diff 
--git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage.py_tmpl new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage.py_tmpl @@ -0,0 +1,5 @@ +#!/usr/bin/env python +from migrate.versioning.shell import main + +if __name__ == '__main__': + main(%(defaults)s) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/default.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/default.py_tmpl new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/default.py_tmpl @@ -0,0 +1,10 @@ +#!/usr/bin/env python +from migrate.versioning.shell import main + +{{py: +_vars = locals().copy() +del _vars['__template_name__'] +_vars.pop('repository_name', None) +defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()]) +}} +main({{ defaults }}) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/pylons.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/pylons.py_tmpl new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/manage/pylons.py_tmpl @@ -0,0 +1,29 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import sys + +from sqlalchemy import engine_from_config +from paste.deploy.loadwsgi import ConfigLoader + +from migrate.versioning.shell import main +from {{ locals().pop('repository_name') }}.model import migrations + + +if '-c' in sys.argv: + pos = sys.argv.index('-c') + conf_path = sys.argv[pos + 1] + del sys.argv[pos:pos + 2] +else: + conf_path = 'development.ini' + +{{py: +_vars = locals().copy() +del _vars['__template_name__'] +defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()]) +}} + +conf_dict = ConfigLoader(conf_path).parser._sections['app:main'] + +# migrate supports passing url as an existing Engine instance (since 0.6.0) +# usage: migrate -c path/to/config.ini 
COMMANDS +main(url=engine_from_config(conf_dict), repository=migrations.__path__[0],{{ defaults }}) diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/README b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/README new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/README @@ -0,0 +1,4 @@ +This is a database migration repository. + +More information at +http://code.google.com/p/sqlalchemy-migrate/ diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/migrate.cfg b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/migrate.cfg new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/migrate.cfg @@ -0,0 +1,20 @@ +[db_settings] +# Used to identify which repository this database is versioned under. +# You can use the name of your project. +repository_id={{ locals().pop('repository_id') }} + +# The name of the database table used to track the schema version. +# This name shouldn't already be used by your project. +# If this is changed once a database is under version control, you'll need to +# change the table name in each database too. +version_table={{ locals().pop('version_table') }} + +# When committing a change script, Migrate will attempt to generate the +# sql for all supported databases; normally, if one of them fails - probably +# because you don't have that database installed - it is ignored and the +# commit continues, perhaps ending successfully. 
+# Databases in this list MUST compile successfully during a commit, or the +# entire commit will fail. List the databases your application will actually +# be using to ensure your updates to that database work properly. +# This must be a list; example: ['postgres','sqlite'] +required_dbs={{ locals().pop('required_dbs') }} diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/versions/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/default/versions/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/README b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/README new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/README @@ -0,0 +1,4 @@ +This is a database migration repository. + +More information at +http://code.google.com/p/sqlalchemy-migrate/ diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/migrate.cfg b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/migrate.cfg new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/migrate.cfg @@ -0,0 +1,20 @@ +[db_settings] +# Used to identify which repository this database is versioned under. +# You can use the name of your project. +repository_id={{ locals().pop('repository_id') }} + +# The name of the database table used to track the schema version. +# This name shouldn't already be used by your project. +# If this is changed once a database is under version control, you'll need to +# change the table name in each database too. 
+version_table={{ locals().pop('version_table') }} + +# When committing a change script, Migrate will attempt to generate the +# sql for all supported databases; normally, if one of them fails - probably +# because you don't have that database installed - it is ignored and the +# commit continues, perhaps ending successfully. +# Databases in this list MUST compile successfully during a commit, or the +# entire commit will fail. List the databases your application will actually +# be using to ensure your updates to that database work properly. +# This must be a list; example: ['postgres','sqlite'] +required_dbs={{ locals().pop('required_dbs') }} diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/versions/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/repository/pylons/versions/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/__init__.py new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/default.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/default.py_tmpl new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/default.py_tmpl @@ -0,0 +1,11 @@ +from sqlalchemy import * +from migrate import * + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; bind migrate_engine + # to your metadata + pass + +def downgrade(migrate_engine): + # Operations to reverse the above upgrade go here. 
+ pass diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/pylons.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/pylons.py_tmpl new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/templates/script/pylons.py_tmpl @@ -0,0 +1,11 @@ +from sqlalchemy import * +from migrate import * + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; bind migrate_engine + # to your metadata + pass + +def downgrade(migrate_engine): + # Operations to reverse the above upgrade go here. + pass diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/default.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/default.py_tmpl new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/pylons.py_tmpl b/rhodecode/lib/dbmigrate/migrate/versioning/templates/sql_script/pylons.py_tmpl new file mode 100644 diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/util/__init__.py b/rhodecode/lib/dbmigrate/migrate/versioning/util/__init__.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/util/__init__.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""".. currentmodule:: migrate.versioning.util""" + +import warnings +import logging +from decorator import decorator +from pkg_resources import EntryPoint + +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.pool import StaticPool + +from migrate import exceptions +from migrate.versioning.util.keyedinstance import KeyedInstance +from migrate.versioning.util.importpath import import_path + + +log = logging.getLogger(__name__) + +def load_model(dotted_name): + """Import module and use module-level variable". + + :param dotted_name: path to model in form of string: ``some.python.module:Class`` + + .. 
try:
    # Python 2: ``basestring`` covers both ``str`` and ``unicode``.
    _string_types = basestring
except NameError:
    # Interpreters without ``basestring`` only have ``str`` for text.
    _string_types = str


def asbool(obj):
    """Interpret ``obj`` as a boolean.

    Strings are stripped and lower-cased, then matched against the accepted
    spellings ``true/yes/on/y/t/1`` and ``false/no/off/n/f/0``.  Any other
    object is accepted only if it compares equal to ``True`` or ``False``.

    :param obj: string or bool-like object to interpret
    :returns: ``True`` or ``False``
    :raises ValueError: if ``obj`` cannot be read as a boolean
    """
    if isinstance(obj, _string_types):
        text = obj.strip().lower()
        if text in ('true', 'yes', 'on', 'y', 't', '1'):
            return True
        if text in ('false', 'no', 'off', 'n', 'f', '0'):
            return False
        raise ValueError("String is not true/false: %r" % (text,))
    if obj in (True, False):
        return bool(obj)
    # Wrap in a 1-tuple: with a bare ``% obj`` a tuple argument would be
    # unpacked as several format arguments and raise TypeError instead.
    raise ValueError("String is not true/false: %r" % (obj,))


def guess_obj_type(obj):
    """Best-effort conversion of ``obj`` to ``int`` or ``bool``.

    Tries ``int(obj)`` first, then :func:`asbool`; when neither conversion
    succeeds the object is returned unchanged.

    .. versionadded: 0.5.4

    :param obj: value to convert (typically a string option value)
    :returns: an ``int``, a ``bool``, or ``obj`` itself
    """
    # Conversion failures are expected here and deliberately ignored, but
    # only ``Exception`` is caught so KeyboardInterrupt/SystemExit propagate.
    try:
        return int(obj)
    except Exception:
        pass

    try:
        return asbool(obj)
    except Exception:
        return obj
+ + Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions: + + :param engine: connection string or a existing engine + :param engine_dict: python dictionary of options to pass to `create_engine` + :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`) + :type engine_dict: dict + :type engine: string or Engine instance + :type engine_arg_*: string + :returns: SQLAlchemy Engine + + .. note:: + + keyword parameters override ``engine_dict`` values. + + """ + if isinstance(engine, Engine): + return engine + elif not isinstance(engine, basestring): + raise ValueError("you need to pass either an existing engine or a database uri") + + # get options for create_engine + if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): + kwargs = opts['engine_dict'] + else: + kwargs = dict() + + # DEPRECATED: handle echo the old way + echo = asbool(opts.get('echo', False)) + if echo: + warnings.warn('echo=True parameter is deprecated, pass ' + 'engine_arg_echo=True or engine_dict={"echo": True}', + exceptions.MigrateDeprecationWarning) + kwargs['echo'] = echo + + # parse keyword arguments + for key, value in opts.iteritems(): + if key.startswith('engine_arg_'): + kwargs[key[11:]] = guess_obj_type(value) + + log.debug('Constructing engine') + # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs) + # seems like 0.5.x branch does not work with engine.dispose and staticpool + return create_engine(engine, **kwargs) + +@decorator +def with_engine(f, *a, **kw): + """Decorator for :mod:`migrate.versioning.api` functions + to safely close resources after function usage. + + Passes engine parameters to :func:`construct_engine` and + resulting parameter is available as kw['engine']. + + Engine is disposed after wrapped function is executed. + + .. 
class Memoize(object):
    """Memoize(fn) - an instance which acts like fn but memoizes its arguments

    Results are cached in ``self.memo`` keyed by the positional-argument
    tuple, so this only works for functions called with hashable
    (non-mutable) positional arguments.

    ActiveState Code 52201

    :param fn: the callable to wrap
    """

    def __init__(self, fn):
        self.fn = fn
        self.memo = {}

    def __call__(self, *args):
        """Return the cached result for ``args``, computing it on first use."""
        # Single lookup on the hit path; ``dict.has_key`` is deprecated.
        try:
            return self.memo[args]
        except KeyError:
            pass
        result = self.memo[args] = self.fn(*args)
        return result
class VerNum(object):
    """A version number that behaves like a string and int at the same time

    Instances are interned in ``_instances``: constructing a ``VerNum`` for a
    value that was seen before returns the previously created object.

    :param value: anything accepted by ``int()`` (int, numeric string, VerNum)
    :raises ValueError: if the version number is negative
    """

    _instances = dict()

    def __new__(cls, value):
        # Key on the normalised integer text so that '004', '4' and 4 all
        # map to the same interned instance.
        key = str(int(value))
        try:
            return cls._instances[key]
        except KeyError:
            pass
        instance = cls._instances[key] = super(VerNum, cls).__new__(cls)
        return instance

    def __init__(self, value):
        self.value = str(int(value))
        # Compare as plain ints so the check does not depend on __cmp__.
        if int(self.value) < 0:
            raise ValueError("Version number cannot be negative")

    def __add__(self, value):
        """Return the ``VerNum`` for ``int(self) + int(value)``."""
        return VerNum(int(self) + int(value))

    def __sub__(self, value):
        """Return the ``VerNum`` for ``int(self) - int(value)``."""
        return self + (int(value) * -1)

    def __cmp__(self, value):
        return int(self) - int(value)

    def __repr__(self):
        return "<VerNum(%s)>" % self.value

    def __str__(self):
        return str(self.value)

    def __int__(self):
        return int(self.value)
+class Collection(pathed.Pathed): + """A collection of versioning scripts in a repository""" + + FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*') + + def __init__(self, path): + """Collect current version scripts in repository + and store them in self.versions + """ + super(Collection, self).__init__(path) + + # Create temporary list of files, allowing skipped version numbers. + files = os.listdir(path) + if '1' in files: + # deprecation + raise Exception('It looks like you have a repository in the old ' + 'format (with directories for each version). ' + 'Please convert repository before proceeding.') + + tempVersions = dict() + for filename in files: + match = self.FILENAME_WITH_VERSION.match(filename) + if match: + num = int(match.group(1)) + tempVersions.setdefault(num, []).append(filename) + else: + pass # Must be a helper file or something, let's ignore it. + + # Create the versions member where the keys + # are VerNum's and the values are Version's. + self.versions = dict() + for num, files in tempVersions.items(): + self.versions[VerNum(num)] = Version(num, path, files) + + @property + def latest(self): + """:returns: Latest version in Collection""" + return max([VerNum(0)] + self.versions.keys()) + + def create_new_python_version(self, description, **k): + """Create Python files for new version""" + ver = self.latest + 1 + extra = str_to_filename(description) + + if extra: + if extra == '_': + extra = '' + elif not extra.startswith('_'): + extra = '_%s' % extra + + filename = '%03d%s.py' % (ver, extra) + filepath = self._version_path(filename) + + script.PythonScript.create(filepath, **k) + self.versions[ver] = Version(ver, self.path, [filename]) + + def create_new_sql_version(self, database, **k): + """Create SQL files for new version""" + ver = self.latest + 1 + self.versions[ver] = Version(ver, self.path, []) + + # Create new files. 
class Version(object):
    """A single version in a collection

    :param vernum: Version Number
    :param path: Path to script files
    :param filelist: List of scripts
    :type vernum: int, VerNum
    :type path: string
    :type filelist: list
    """

    # ###_database_operation.sql -- the dot is escaped and the pattern is
    # anchored at the end so only a literal ".sql" suffix is accepted.
    SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+)\.sql$')

    def __init__(self, vernum, path, filelist):
        self.version = VerNum(vernum)

        # Collect scripts in this folder
        self.sql = dict()    # {dbms: {operation: SqlScript}}
        self.python = None   # at most one PythonScript per version

        for filename in filelist:
            self.add_script(os.path.join(path, filename))

    def script(self, database=None, operation=None):
        """Returns SQL or Python Script

        A ``.sql`` script registered for ``database`` (falling back to the
        ``'default'`` entry) and ``operation`` wins; otherwise the Python
        script of this version is returned.
        """
        for db in (database, 'default'):
            # Try to return a .sql script first
            try:
                return self.sql[db][operation]
            except KeyError:
                continue  # No .sql script exists

        # TODO: maybe add force Python parameter?
        ret = self.python

        assert ret is not None, \
            "There is no script for %d version" % self.version
        return ret

    def add_script(self, path):
        """Add script to Collection/Version

        Files whose extension is neither ``.py`` nor ``.sql`` are ignored.
        """
        # Require the dot: a bare suffix test would also accept names that
        # merely end in the letters "py" or "sql".
        if path.endswith('.' + Extensions.py):
            self._add_script_py(path)
        elif path.endswith('.' + Extensions.sql):
            self._add_script_sql(path)

    def _add_script_sql(self, path):
        """Register ``path`` under the dbms/operation parsed from its name.

        :raises exceptions.ScriptError: if the file name does not follow
            the ``###_database_operation.sql`` convention
        """
        basename = os.path.basename(path)
        match = self.SQL_FILENAME.match(basename)

        if match is None:
            raise exceptions.ScriptError(
                "Invalid SQL script name %s " % basename +
                "(needs to be ###_database_operation.sql)")
        dbms, op = match.group(2), match.group(3)

        # File the script into a dictionary
        self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)

    def _add_script_py(self, path):
        """Register ``path`` as the single Python script of this version.

        :raises exceptions.ScriptError: if a Python script is already set
        """
        if self.python is not None:
            raise exceptions.ScriptError('You can only have one Python script '
                'per version, but you have: %s and %s' % (self.python, path))
        self.python = script.PythonScript(path)


class Extensions:
    """A namespace for file extensions"""
    py = 'py'
    sql = 'sql'
class RhodeCodeSettings(Base, BaseModel):
    """Model for the ``rhodecode_settings`` table: one name/value pair per row.

    :param k: value stored in ``app_settings_name``
    :param v: value stored in ``app_settings_value``
    """
    __tablename__ = 'rhodecode_settings'
    __table_args__ = (UniqueConstraint('app_settings_name'), {'useexisting':True})
    app_settings_id = Column("app_settings_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
    app_settings_name = Column("app_settings_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    app_settings_value = Column("app_settings_value", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)

    def __init__(self, k, v):
        self.app_settings_name = k
        self.app_settings_value = v

    def __repr__(self):
        # The format string must carry one placeholder per tuple item;
        # an empty string makes the % operator raise TypeError.
        return "<RhodeCodeSettings('%s:%s')>" % (self.app_settings_name,
                                                 self.app_settings_value)
class User(Base, BaseModel):
    """Model for the ``users`` table, with relations to the user's logs,
    permissions, repositories and followers.
    """
    __tablename__ = 'users'
    __table_args__ = (UniqueConstraint('username'), UniqueConstraint('email'), {'useexisting':True})
    user_id = Column("user_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
    username = Column("username", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    password = Column("password", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    active = Column("active", Boolean(), nullable=True, unique=None, default=None)
    admin = Column("admin", Boolean(), nullable=True, unique=None, default=False)
    name = Column("name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    lastname = Column("lastname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    email = Column("email", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    last_login = Column("last_login", DateTime(timezone=False), nullable=True, unique=None, default=None)
    is_ldap = Column("is_ldap", Boolean(), nullable=False, unique=None, default=False)

    user_log = relation('UserLog', cascade='all')
    user_perms = relation('UserToPerm', primaryjoin="User.user_id==UserToPerm.user_id", cascade='all')

    repositories = relation('Repository')
    user_followers = relation('UserFollowing', primaryjoin='UserFollowing.follows_user_id==User.user_id', cascade='all')

    @property
    def full_contact(self):
        """Return ``"name lastname <email>"`` for this user."""
        return '%s %s <%s>' % (self.name, self.lastname, self.email)

    def __repr__(self):
        # One placeholder per tuple item; an empty format string makes the
        # % operator raise TypeError.
        return "<User('id:%s:%s')>" % (self.user_id, self.username)

    def update_lastlogin(self):
        """Update user lastlogin

        Sets ``last_login`` to the current time and commits it through the
        session this object belongs to; the session is rolled back if the
        database reports an error.
        """
        # Resolve the session before entering ``try`` so the handler below
        # can never reference an unbound ``session`` name.
        session = Session.object_session(self)
        try:
            self.last_login = datetime.datetime.now()
            session.add(self)
            session.commit()
            log.debug('updated user %s lastlogin', self.username)
        except (DatabaseError,):
            session.rollback()
class Permission(Base, BaseModel):
    """Model for the ``permissions`` table (id, short name and long name)."""
    __tablename__ = 'permissions'
    __table_args__ = {'useexisting':True}
    permission_id = Column("permission_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
    permission_name = Column("permission_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    permission_longname = Column("permission_longname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)

    def __repr__(self):
        # One placeholder per tuple item; an empty format string makes the
        # % operator raise TypeError.
        return "<Permission('%s:%s')>" % (self.permission_id, self.permission_name)
Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None) + permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None) + repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=None, default=None) + + user = relation('User') + permission = relation('Permission') + repository = relation('Repository') + +class UserToPerm(Base, BaseModel): + __tablename__ = 'user_to_perm' + __table_args__ = (UniqueConstraint('user_id', 'permission_id'), {'useexisting':True}) + user_to_perm_id = Column("user_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True) + user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None) + permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None) + + user = relation('User') + permission = relation('Permission') + +class Statistics(Base, BaseModel): + __tablename__ = 'statistics' + __table_args__ = (UniqueConstraint('repository_id'), {'useexisting':True}) + stat_id = Column("stat_id", Integer(), nullable=False, unique=True, default=None, primary_key=True) + repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=True, default=None) + stat_on_revision = Column("stat_on_revision", Integer(), nullable=False) + commit_activity = Column("commit_activity", LargeBinary(), nullable=False)#JSON data + commit_activity_combined = Column("commit_activity_combined", LargeBinary(), nullable=False)#JSON data + languages = Column("languages", LargeBinary(), nullable=False)#JSON data + + repository = relation('Repository', single_parent=True) + +class UserFollowing(Base, BaseModel): + __tablename__ = 'user_followings' + __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'), + 
class CacheInvalidation(Base, BaseModel):
    """Model for the ``cache_invalidation`` table.

    :param cache_key: value stored in ``cache_key`` (unique per row)
    :param cache_args: value stored in ``cache_args``; defaults to ``''``
    """
    __tablename__ = 'cache_invalidation'
    __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
    cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
    cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
    cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)

    def __init__(self, cache_key, cache_args=''):
        self.cache_key = cache_key
        self.cache_args = cache_args
        # New entries always start out inactive.
        self.cache_active = False

    def __repr__(self):
        # One placeholder per tuple item; an empty format string makes the
        # % operator raise TypeError.
        return "<CacheInvalidation('%s:%s')>" % (self.cache_id, self.cache_key)
def upgrade(migrate_engine):
    """Upgrade the database schema to the 1.1.0 layout.

    Adds ``is_ldap`` to ``users``, ``revision`` to ``user_logs``,
    ``repo_type`` and ``statistics`` to ``repositories``, then creates
    the ``user_followings`` and ``cache_invalidation`` tables.

    Don't create your own engine; every operation here is bound to
    ``migrate_engine``.

    :param migrate_engine: engine handed in by the migration runner
    """

    #==========================================================================
    # Upgrade of `users` table
    #==========================================================================
    tblname = 'users'
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)

    #ADD is_ldap column
    # NOTE(review): NOT NULL column with only a Python-side default; some
    # backends may reject adding it to a table that already holds rows --
    # confirm against the supported databases.
    is_ldap = Column("is_ldap", Boolean(), nullable=False,
                     unique=None, default=False)
    is_ldap.create(tbl)

    #==========================================================================
    # Upgrade of `user_logs` table
    #==========================================================================
    # Reflect `user_logs` itself; reflecting `users` here would attach the
    # new column to the wrong table.
    tblname = 'user_logs'
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)

    #ADD revision column
    revision = Column('revision', TEXT(length=None, convert_unicode=False,
                                       assert_unicode=None),
                      nullable=True, unique=None, default=None)
    revision.create(tbl)

    #==========================================================================
    # Upgrade of `repositories` table
    #==========================================================================
    # Same as above: the columns below belong to `repositories`.
    tblname = 'repositories'
    tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)

    #ADD repo_type column
    # NOTE(review): NOT NULL without a default -- see the note on is_ldap.
    repo_type = Column("repo_type", String(length=None, convert_unicode=False,
                                           assert_unicode=None),
                       nullable=False, unique=False, default=None)
    repo_type.create(tbl)

    #ADD statistics column
    enable_statistics = Column("statistics", Boolean(), nullable=True,
                               unique=None, default=True)
    enable_statistics.create(tbl)

    #==========================================================================
    # Add table `user_followings`
    #==========================================================================
    tblname = 'user_followings'

    class UserFollowing(Base, BaseModel):
        __tablename__ = 'user_followings'
        __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
                          UniqueConstraint('user_id', 'follows_user_id')
                          , {'useexisting':True})

        user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
        user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
        follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=None, default=None)
        follows_user_id = Column("follows_user_id", Integer(), ForeignKey(u'users.user_id'), nullable=True, unique=None, default=None)

        user = relation('User', primaryjoin='User.user_id==UserFollowing.user_id')

        follows_user = relation('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
        follows_repository = relation('Repository')

    Base.metadata.tables[tblname].create(migrate_engine)

    #==========================================================================
    # Add table `cache_invalidation`
    #==========================================================================
    # tblname must be switched before the create() call below, otherwise
    # `user_followings` is created a second time and `cache_invalidation`
    # is never created.
    tblname = 'cache_invalidation'

    class CacheInvalidation(Base, BaseModel):
        __tablename__ = 'cache_invalidation'
        __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
        cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
        cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
        cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
        cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)

        def __init__(self, cache_key, cache_args=''):
            self.cache_key = cache_key
            self.cache_args = cache_args
            self.cache_active = False

        def __repr__(self):
            # One placeholder per argument; an empty format string raises
            # TypeError on formatting.
            return "<CacheInvalidation('%s:%s')>" % (self.cache_id,
                                                     self.cache_key)

    Base.metadata.tables[tblname].create(migrate_engine)
diff --git a/rhodecode/lib/utils.py b/rhodecode/lib/utils.py --- a/rhodecode/lib/utils.py +++ b/rhodecode/lib/utils.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ - package.rhodecode.lib.utils - ~~~~~~~~~~~~~~ + rhodecode.lib.utils + ~~~~~~~~~~~~~~~~~~~ Utilities library for RhodeCode @@ -599,30 +599,3 @@ class BasePasterCommand(Command): path_to_ini_file = os.path.realpath(conf) conf = paste.deploy.appconfig('config:' + path_to_ini_file) pylonsconfig.init_app(conf.global_conf, conf.local_conf) - - - -class UpgradeDb(BasePasterCommand): - """Command used for paster to upgrade our database to newer version - """ - - max_args = 1 - min_args = 1 - - usage = "CONFIG_FILE" - summary = "Upgrades current db to newer version given configuration file" - group_name = "RhodeCode" - - parser = Command.standard_parser(verbose=True) - - def command(self): - from pylons import config - raise NotImplementedError('Not implemented yet') - - - def update_parser(self): - self.parser.add_option('--sql', - action='store_true', - dest='just_sql', - help="Prints upgrade sql for further investigation", - default=False) diff --git a/setup.py b/setup.py --- a/setup.py +++ b/setup.py @@ -93,7 +93,7 @@ setup( [paste.global_paster_command] make-index = rhodecode.lib.indexers:MakeIndex - upgrade-db = rhodecode.lib.utils:UpgradeDb + upgrade-db = rhodecode.lib.dbmigrate:UpgradeDb celeryd=rhodecode.lib.celerypylons.commands:CeleryDaemonCommand celerybeat=rhodecode.lib.celerypylons.commands:CeleryBeatCommand camqadm=rhodecode.lib.celerypylons.commands:CAMQPAdminCommand