##// END OF EJS Templates
updated sqlalchemy migrate to latest version
marcink -
r1061:9bb609d1 beta
parent child Browse files
Show More
@@ -1,80 +1,99 b''
1 """
1 """
2 Firebird database specific implementations of changeset classes.
2 Firebird database specific implementations of changeset classes.
3 """
3 """
4 from sqlalchemy.databases import firebird as sa_base
4 from sqlalchemy.databases import firebird as sa_base
5
5 from sqlalchemy.schema import PrimaryKeyConstraint
6 from rhodecode.lib.dbmigrate.migrate import exceptions
6 from rhodecode.lib.dbmigrate.migrate import exceptions
7 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
7 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
8
8
9
9
10 if SQLA_06:
10 if SQLA_06:
11 FBSchemaGenerator = sa_base.FBDDLCompiler
11 FBSchemaGenerator = sa_base.FBDDLCompiler
12 else:
12 else:
13 FBSchemaGenerator = sa_base.FBSchemaGenerator
13 FBSchemaGenerator = sa_base.FBSchemaGenerator
14
14
15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
16 """Firebird column generator implementation."""
16 """Firebird column generator implementation."""
17
17
18
18
19 class FBColumnDropper(ansisql.ANSIColumnDropper):
19 class FBColumnDropper(ansisql.ANSIColumnDropper):
20 """Firebird column dropper implementation."""
20 """Firebird column dropper implementation."""
21
21
22 def visit_column(self, column):
22 def visit_column(self, column):
23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
24
24
25 Drop primary key and unique constraints if dropped column is referencing it."""
25 Drop primary key and unique constraints if dropped column is referencing it."""
26 if column.primary_key:
26 if column.primary_key:
27 if column.table.primary_key.columns.contains_column(column):
27 if column.table.primary_key.columns.contains_column(column):
28 column.table.primary_key.drop()
28 column.table.primary_key.drop()
29 # TODO: recreate primary key if it references more than this column
29 # TODO: recreate primary key if it references more than this column
30 if column.unique or getattr(column, 'unique_name', None):
30
31 for index in column.table.indexes:
32 # "column in index.columns" causes problems as all
33 # column objects compare equal and return a SQL expression
34 if column.name in [col.name for col in index.columns]:
35 index.drop()
36 # TODO: recreate index if it references more than this column
37
31 for cons in column.table.constraints:
38 for cons in column.table.constraints:
32 if cons.contains_column(column):
39 if isinstance(cons,PrimaryKeyConstraint):
33 cons.drop()
40 # will be deleted only when the column its on
41 # is deleted!
42 continue
43
44 if SQLA_06:
45 should_drop = column.name in cons.columns
46 else:
47 should_drop = cons.contains_column(column) and cons.name
48 if should_drop:
49 self.start_alter_table(column)
50 self.append("DROP CONSTRAINT ")
51 self.append(self.preparer.format_constraint(cons))
52 self.execute()
34 # TODO: recreate unique constraint if it refenrences more than this column
53 # TODO: recreate unique constraint if it refenrences more than this column
35
54
36 table = self.start_alter_table(column)
55 self.start_alter_table(column)
37 self.append('DROP %s' % self.preparer.format_column(column))
56 self.append('DROP %s' % self.preparer.format_column(column))
38 self.execute()
57 self.execute()
39
58
40
59
41 class FBSchemaChanger(ansisql.ANSISchemaChanger):
60 class FBSchemaChanger(ansisql.ANSISchemaChanger):
42 """Firebird schema changer implementation."""
61 """Firebird schema changer implementation."""
43
62
44 def visit_table(self, table):
63 def visit_table(self, table):
45 """Rename table not supported"""
64 """Rename table not supported"""
46 raise exceptions.NotSupportedError(
65 raise exceptions.NotSupportedError(
47 "Firebird does not support renaming tables.")
66 "Firebird does not support renaming tables.")
48
67
49 def _visit_column_name(self, table, column, delta):
68 def _visit_column_name(self, table, column, delta):
50 self.start_alter_table(table)
69 self.start_alter_table(table)
51 col_name = self.preparer.quote(delta.current_name, table.quote)
70 col_name = self.preparer.quote(delta.current_name, table.quote)
52 new_name = self.preparer.format_column(delta.result_column)
71 new_name = self.preparer.format_column(delta.result_column)
53 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
72 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
54
73
55 def _visit_column_nullable(self, table, column, delta):
74 def _visit_column_nullable(self, table, column, delta):
56 """Changing NULL is not supported"""
75 """Changing NULL is not supported"""
57 # TODO: http://www.firebirdfaq.org/faq103/
76 # TODO: http://www.firebirdfaq.org/faq103/
58 raise exceptions.NotSupportedError(
77 raise exceptions.NotSupportedError(
59 "Firebird does not support altering NULL bevahior.")
78 "Firebird does not support altering NULL bevahior.")
60
79
61
80
62 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
81 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
63 """Firebird constraint generator implementation."""
82 """Firebird constraint generator implementation."""
64
83
65
84
66 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
85 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
67 """Firebird constaint dropper implementation."""
86 """Firebird constaint dropper implementation."""
68
87
69 def cascade_constraint(self, constraint):
88 def cascade_constraint(self, constraint):
70 """Cascading constraints is not supported"""
89 """Cascading constraints is not supported"""
71 raise exceptions.NotSupportedError(
90 raise exceptions.NotSupportedError(
72 "Firebird does not support cascading constraints")
91 "Firebird does not support cascading constraints")
73
92
74
93
75 class FBDialect(ansisql.ANSIDialect):
94 class FBDialect(ansisql.ANSIDialect):
76 columngenerator = FBColumnGenerator
95 columngenerator = FBColumnGenerator
77 columndropper = FBColumnDropper
96 columndropper = FBColumnDropper
78 schemachanger = FBSchemaChanger
97 schemachanger = FBSchemaChanger
79 constraintgenerator = FBConstraintGenerator
98 constraintgenerator = FBConstraintGenerator
80 constraintdropper = FBConstraintDropper
99 constraintdropper = FBConstraintDropper
@@ -1,148 +1,155 b''
1 """
1 """
2 `SQLite`_ database specific implementations of changeset classes.
2 `SQLite`_ database specific implementations of changeset classes.
3
3
4 .. _`SQLite`: http://www.sqlite.org/
4 .. _`SQLite`: http://www.sqlite.org/
5 """
5 """
6 from UserDict import DictMixin
6 from UserDict import DictMixin
7 from copy import copy
7 from copy import copy
8
8
9 from sqlalchemy.databases import sqlite as sa_base
9 from sqlalchemy.databases import sqlite as sa_base
10
10
11 from rhodecode.lib.dbmigrate.migrate import exceptions
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
12 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
13
13
14
14
15 if not SQLA_06:
15 if not SQLA_06:
16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
17 else:
17 else:
18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
19
19
20 class SQLiteCommon(object):
20 class SQLiteCommon(object):
21
21
22 def _not_supported(self, op):
22 def _not_supported(self, op):
23 raise exceptions.NotSupportedError("SQLite does not support "
23 raise exceptions.NotSupportedError("SQLite does not support "
24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
25
25
26
26
27 class SQLiteHelper(SQLiteCommon):
27 class SQLiteHelper(SQLiteCommon):
28
28
29 def recreate_table(self,table,column=None,delta=None):
29 def recreate_table(self,table,column=None,delta=None):
30 table_name = self.preparer.format_table(table)
30 table_name = self.preparer.format_table(table)
31
31
32 # we remove all indexes so as not to have
32 # we remove all indexes so as not to have
33 # problems during copy and re-create
33 # problems during copy and re-create
34 for index in table.indexes:
34 for index in table.indexes:
35 index.drop()
35 index.drop()
36
36
37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
38 self.execute()
38 self.execute()
39
39
40 insertion_string = self._modify_table(table, column, delta)
40 insertion_string = self._modify_table(table, column, delta)
41
41
42 table.create()
42 table.create()
43 self.append(insertion_string % {'table_name': table_name})
43 self.append(insertion_string % {'table_name': table_name})
44 self.execute()
44 self.execute()
45 self.append('DROP TABLE migration_tmp')
45 self.append('DROP TABLE migration_tmp')
46 self.execute()
46 self.execute()
47
47
48 def visit_column(self, delta):
48 def visit_column(self, delta):
49 if isinstance(delta, DictMixin):
49 if isinstance(delta, DictMixin):
50 column = delta.result_column
50 column = delta.result_column
51 table = self._to_table(delta.table)
51 table = self._to_table(delta.table)
52 else:
52 else:
53 column = delta
53 column = delta
54 table = self._to_table(column.table)
54 table = self._to_table(column.table)
55 self.recreate_table(table,column,delta)
55 self.recreate_table(table,column,delta)
56
56
57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
58 ansisql.ANSIColumnGenerator,
58 ansisql.ANSIColumnGenerator,
59 # at the end so we get the normal
59 # at the end so we get the normal
60 # visit_column by default
60 # visit_column by default
61 SQLiteHelper,
61 SQLiteHelper,
62 SQLiteCommon
62 SQLiteCommon
63 ):
63 ):
64 """SQLite ColumnGenerator"""
64 """SQLite ColumnGenerator"""
65
65
66 def _modify_table(self, table, column, delta):
66 def _modify_table(self, table, column, delta):
67 columns = ' ,'.join(map(
67 columns = ' ,'.join(map(
68 self.preparer.format_column,
68 self.preparer.format_column,
69 [c for c in table.columns if c.name!=column.name]))
69 [c for c in table.columns if c.name!=column.name]))
70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
72
72
73 def visit_column(self,column):
73 def visit_column(self,column):
74 if column.foreign_keys:
74 if column.foreign_keys:
75 SQLiteHelper.visit_column(self,column)
75 SQLiteHelper.visit_column(self,column)
76 else:
76 else:
77 super(SQLiteColumnGenerator,self).visit_column(column)
77 super(SQLiteColumnGenerator,self).visit_column(column)
78
78
79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
80 """SQLite ColumnDropper"""
80 """SQLite ColumnDropper"""
81
81
82 def _modify_table(self, table, column, delta):
82 def _modify_table(self, table, column, delta):
83
83 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
84 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
84 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
85 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
85 ' from migration_tmp'
86 ' from migration_tmp'
86
87
88 def visit_column(self,column):
89 # For SQLite, we *have* to remove the column here so the table
90 # is re-created properly.
91 column.remove_from_table(column.table,unset_table=False)
92 super(SQLiteColumnDropper,self).visit_column(column)
93
87
94
88 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
95 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
89 """SQLite SchemaChanger"""
96 """SQLite SchemaChanger"""
90
97
91 def _modify_table(self, table, column, delta):
98 def _modify_table(self, table, column, delta):
92 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
99 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
93
100
94 def visit_index(self, index):
101 def visit_index(self, index):
95 """Does not support ALTER INDEX"""
102 """Does not support ALTER INDEX"""
96 self._not_supported('ALTER INDEX')
103 self._not_supported('ALTER INDEX')
97
104
98
105
99 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
106 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
100
107
101 def visit_migrate_primary_key_constraint(self, constraint):
108 def visit_migrate_primary_key_constraint(self, constraint):
102 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
109 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
103 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
110 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
104 tname = self.preparer.format_table(constraint.table)
111 tname = self.preparer.format_table(constraint.table)
105 name = self.get_constraint_name(constraint)
112 name = self.get_constraint_name(constraint)
106 msg = tmpl % (name, tname, cols)
113 msg = tmpl % (name, tname, cols)
107 self.append(msg)
114 self.append(msg)
108 self.execute()
115 self.execute()
109
116
110 def _modify_table(self, table, column, delta):
117 def _modify_table(self, table, column, delta):
111 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
118 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
112
119
113 def visit_migrate_foreign_key_constraint(self, *p, **k):
120 def visit_migrate_foreign_key_constraint(self, *p, **k):
114 self.recreate_table(p[0].table)
121 self.recreate_table(p[0].table)
115
122
116 def visit_migrate_unique_constraint(self, *p, **k):
123 def visit_migrate_unique_constraint(self, *p, **k):
117 self.recreate_table(p[0].table)
124 self.recreate_table(p[0].table)
118
125
119
126
120 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
127 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
121 SQLiteCommon,
128 SQLiteCommon,
122 ansisql.ANSIConstraintCommon):
129 ansisql.ANSIConstraintCommon):
123
130
124 def visit_migrate_primary_key_constraint(self, constraint):
131 def visit_migrate_primary_key_constraint(self, constraint):
125 tmpl = "DROP INDEX %s "
132 tmpl = "DROP INDEX %s "
126 name = self.get_constraint_name(constraint)
133 name = self.get_constraint_name(constraint)
127 msg = tmpl % (name)
134 msg = tmpl % (name)
128 self.append(msg)
135 self.append(msg)
129 self.execute()
136 self.execute()
130
137
131 def visit_migrate_foreign_key_constraint(self, *p, **k):
138 def visit_migrate_foreign_key_constraint(self, *p, **k):
132 self._not_supported('ALTER TABLE DROP CONSTRAINT')
139 self._not_supported('ALTER TABLE DROP CONSTRAINT')
133
140
134 def visit_migrate_check_constraint(self, *p, **k):
141 def visit_migrate_check_constraint(self, *p, **k):
135 self._not_supported('ALTER TABLE DROP CONSTRAINT')
142 self._not_supported('ALTER TABLE DROP CONSTRAINT')
136
143
137 def visit_migrate_unique_constraint(self, *p, **k):
144 def visit_migrate_unique_constraint(self, *p, **k):
138 self._not_supported('ALTER TABLE DROP CONSTRAINT')
145 self._not_supported('ALTER TABLE DROP CONSTRAINT')
139
146
140
147
141 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
148 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
142
149
143 class SQLiteDialect(ansisql.ANSIDialect):
150 class SQLiteDialect(ansisql.ANSIDialect):
144 columngenerator = SQLiteColumnGenerator
151 columngenerator = SQLiteColumnGenerator
145 columndropper = SQLiteColumnDropper
152 columndropper = SQLiteColumnDropper
146 schemachanger = SQLiteSchemaChanger
153 schemachanger = SQLiteSchemaChanger
147 constraintgenerator = SQLiteConstraintGenerator
154 constraintgenerator = SQLiteConstraintGenerator
148 constraintdropper = SQLiteConstraintDropper
155 constraintdropper = SQLiteConstraintDropper
@@ -1,669 +1,651 b''
1 """
1 """
2 Schema module providing common schema operations.
2 Schema module providing common schema operations.
3 """
3 """
4 import warnings
4 import warnings
5
5
6 from UserDict import DictMixin
6 from UserDict import DictMixin
7
7
8 import sqlalchemy
8 import sqlalchemy
9
9
10 from sqlalchemy.schema import ForeignKeyConstraint
10 from sqlalchemy.schema import ForeignKeyConstraint
11 from sqlalchemy.schema import UniqueConstraint
11 from sqlalchemy.schema import UniqueConstraint
12
12
13 from rhodecode.lib.dbmigrate.migrate.exceptions import *
13 from rhodecode.lib.dbmigrate.migrate.exceptions import *
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
15 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
15 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
16 run_single_visitor)
16 run_single_visitor)
17
17
18
18
19 __all__ = [
19 __all__ = [
20 'create_column',
20 'create_column',
21 'drop_column',
21 'drop_column',
22 'alter_column',
22 'alter_column',
23 'rename_table',
23 'rename_table',
24 'rename_index',
24 'rename_index',
25 'ChangesetTable',
25 'ChangesetTable',
26 'ChangesetColumn',
26 'ChangesetColumn',
27 'ChangesetIndex',
27 'ChangesetIndex',
28 'ChangesetDefaultClause',
28 'ChangesetDefaultClause',
29 'ColumnDelta',
29 'ColumnDelta',
30 ]
30 ]
31
31
32 DEFAULT_ALTER_METADATA = True
33
34
35 def create_column(column, table=None, *p, **kw):
32 def create_column(column, table=None, *p, **kw):
36 """Create a column, given the table.
33 """Create a column, given the table.
37
34
38 API to :meth:`ChangesetColumn.create`.
35 API to :meth:`ChangesetColumn.create`.
39 """
36 """
40 if table is not None:
37 if table is not None:
41 return table.create_column(column, *p, **kw)
38 return table.create_column(column, *p, **kw)
42 return column.create(*p, **kw)
39 return column.create(*p, **kw)
43
40
44
41
45 def drop_column(column, table=None, *p, **kw):
42 def drop_column(column, table=None, *p, **kw):
46 """Drop a column, given the table.
43 """Drop a column, given the table.
47
44
48 API to :meth:`ChangesetColumn.drop`.
45 API to :meth:`ChangesetColumn.drop`.
49 """
46 """
50 if table is not None:
47 if table is not None:
51 return table.drop_column(column, *p, **kw)
48 return table.drop_column(column, *p, **kw)
52 return column.drop(*p, **kw)
49 return column.drop(*p, **kw)
53
50
54
51
55 def rename_table(table, name, engine=None, **kw):
52 def rename_table(table, name, engine=None, **kw):
56 """Rename a table.
53 """Rename a table.
57
54
58 If Table instance is given, engine is not used.
55 If Table instance is given, engine is not used.
59
56
60 API to :meth:`ChangesetTable.rename`.
57 API to :meth:`ChangesetTable.rename`.
61
58
62 :param table: Table to be renamed.
59 :param table: Table to be renamed.
63 :param name: New name for Table.
60 :param name: New name for Table.
64 :param engine: Engine instance.
61 :param engine: Engine instance.
65 :type table: string or Table instance
62 :type table: string or Table instance
66 :type name: string
63 :type name: string
67 :type engine: obj
64 :type engine: obj
68 """
65 """
69 table = _to_table(table, engine)
66 table = _to_table(table, engine)
70 table.rename(name, **kw)
67 table.rename(name, **kw)
71
68
72
69
73 def rename_index(index, name, table=None, engine=None, **kw):
70 def rename_index(index, name, table=None, engine=None, **kw):
74 """Rename an index.
71 """Rename an index.
75
72
76 If Index instance is given,
73 If Index instance is given,
77 table and engine are not used.
74 table and engine are not used.
78
75
79 API to :meth:`ChangesetIndex.rename`.
76 API to :meth:`ChangesetIndex.rename`.
80
77
81 :param index: Index to be renamed.
78 :param index: Index to be renamed.
82 :param name: New name for index.
79 :param name: New name for index.
83 :param table: Table to which Index is reffered.
80 :param table: Table to which Index is reffered.
84 :param engine: Engine instance.
81 :param engine: Engine instance.
85 :type index: string or Index instance
82 :type index: string or Index instance
86 :type name: string
83 :type name: string
87 :type table: string or Table instance
84 :type table: string or Table instance
88 :type engine: obj
85 :type engine: obj
89 """
86 """
90 index = _to_index(index, table, engine)
87 index = _to_index(index, table, engine)
91 index.rename(name, **kw)
88 index.rename(name, **kw)
92
89
93
90
94 def alter_column(*p, **k):
91 def alter_column(*p, **k):
95 """Alter a column.
92 """Alter a column.
96
93
97 This is a helper function that creates a :class:`ColumnDelta` and
94 This is a helper function that creates a :class:`ColumnDelta` and
98 runs it.
95 runs it.
99
96
100 :argument column:
97 :argument column:
101 The name of the column to be altered or a
98 The name of the column to be altered or a
102 :class:`ChangesetColumn` column representing it.
99 :class:`ChangesetColumn` column representing it.
103
100
104 :param table:
101 :param table:
105 A :class:`~sqlalchemy.schema.Table` or table name to
102 A :class:`~sqlalchemy.schema.Table` or table name to
106 for the table where the column will be changed.
103 for the table where the column will be changed.
107
104
108 :param engine:
105 :param engine:
109 The :class:`~sqlalchemy.engine.base.Engine` to use for table
106 The :class:`~sqlalchemy.engine.base.Engine` to use for table
110 reflection and schema alterations.
107 reflection and schema alterations.
111
108
112 :param alter_metadata:
113 If `True`, which is the default, the
114 :class:`~sqlalchemy.schema.Column` will also modified.
115 If `False`, the :class:`~sqlalchemy.schema.Column` will be left
116 as it was.
117
118 :returns: A :class:`ColumnDelta` instance representing the change.
109 :returns: A :class:`ColumnDelta` instance representing the change.
119
110
120
111
121 """
112 """
122
113
123 k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
124
125 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
114 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
126 k['table'] = p[0].table
115 k['table'] = p[0].table
127 if 'engine' not in k:
116 if 'engine' not in k:
128 k['engine'] = k['table'].bind
117 k['engine'] = k['table'].bind
129
118
130 # deprecation
119 # deprecation
131 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
120 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
132 warnings.warn(
121 warnings.warn(
133 "Passing a Column object to alter_column is deprecated."
122 "Passing a Column object to alter_column is deprecated."
134 " Just pass in keyword parameters instead.",
123 " Just pass in keyword parameters instead.",
135 MigrateDeprecationWarning
124 MigrateDeprecationWarning
136 )
125 )
137 engine = k['engine']
126 engine = k['engine']
127
128 # enough tests seem to break when metadata is always altered
129 # that this crutch has to be left in until they can be sorted
130 # out
131 k['alter_metadata']=True
132
138 delta = ColumnDelta(*p, **k)
133 delta = ColumnDelta(*p, **k)
139
134
140 visitorcallable = get_engine_visitor(engine, 'schemachanger')
135 visitorcallable = get_engine_visitor(engine, 'schemachanger')
141 engine._run_visitor(visitorcallable, delta)
136 engine._run_visitor(visitorcallable, delta)
142
137
143 return delta
138 return delta
144
139
145
140
146 def _to_table(table, engine=None):
141 def _to_table(table, engine=None):
147 """Return if instance of Table, else construct new with metadata"""
142 """Return if instance of Table, else construct new with metadata"""
148 if isinstance(table, sqlalchemy.Table):
143 if isinstance(table, sqlalchemy.Table):
149 return table
144 return table
150
145
151 # Given: table name, maybe an engine
146 # Given: table name, maybe an engine
152 meta = sqlalchemy.MetaData()
147 meta = sqlalchemy.MetaData()
153 if engine is not None:
148 if engine is not None:
154 meta.bind = engine
149 meta.bind = engine
155 return sqlalchemy.Table(table, meta)
150 return sqlalchemy.Table(table, meta)
156
151
157
152
158 def _to_index(index, table=None, engine=None):
153 def _to_index(index, table=None, engine=None):
159 """Return if instance of Index, else construct new with metadata"""
154 """Return if instance of Index, else construct new with metadata"""
160 if isinstance(index, sqlalchemy.Index):
155 if isinstance(index, sqlalchemy.Index):
161 return index
156 return index
162
157
163 # Given: index name; table name required
158 # Given: index name; table name required
164 table = _to_table(table, engine)
159 table = _to_table(table, engine)
165 ret = sqlalchemy.Index(index)
160 ret = sqlalchemy.Index(index)
166 ret.table = table
161 ret.table = table
167 return ret
162 return ret
168
163
169
164
170 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
165 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
171 """Extracts the differences between two columns/column-parameters
166 """Extracts the differences between two columns/column-parameters
172
167
173 May receive parameters arranged in several different ways:
168 May receive parameters arranged in several different ways:
174
169
175 * **current_column, new_column, \*p, \*\*kw**
170 * **current_column, new_column, \*p, \*\*kw**
176 Additional parameters can be specified to override column
171 Additional parameters can be specified to override column
177 differences.
172 differences.
178
173
179 * **current_column, \*p, \*\*kw**
174 * **current_column, \*p, \*\*kw**
180 Additional parameters alter current_column. Table name is extracted
175 Additional parameters alter current_column. Table name is extracted
181 from current_column object.
176 from current_column object.
182 Name is changed to current_column.name from current_name,
177 Name is changed to current_column.name from current_name,
183 if current_name is specified.
178 if current_name is specified.
184
179
185 * **current_col_name, \*p, \*\*kw**
180 * **current_col_name, \*p, \*\*kw**
186 Table kw must specified.
181 Table kw must specified.
187
182
188 :param table: Table at which current Column should be bound to.\
183 :param table: Table at which current Column should be bound to.\
189 If table name is given, reflection will be used.
184 If table name is given, reflection will be used.
190 :type table: string or Table instance
185 :type table: string or Table instance
191 :param alter_metadata: If True, it will apply changes to metadata.
186
192 :type alter_metadata: bool
187 :param metadata: A :class:`MetaData` instance to store
193 :param metadata: If `alter_metadata` is true, \
188 reflected table names
194 metadata is used to reflect table names into
189
195 :type metadata: :class:`MetaData` instance
196 :param engine: When reflecting tables, either engine or metadata must \
190 :param engine: When reflecting tables, either engine or metadata must \
197 be specified to acquire engine object.
191 be specified to acquire engine object.
198 :type engine: :class:`Engine` instance
192 :type engine: :class:`Engine` instance
199 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
193 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
200 `result_column` through :func:`dict` alike object.
194 `result_column` through :func:`dict` alike object.
201
195
202 * :class:`ColumnDelta`.result_column is altered column with new attributes
196 * :class:`ColumnDelta`.result_column is altered column with new attributes
203
197
204 * :class:`ColumnDelta`.current_name is current name of column in db
198 * :class:`ColumnDelta`.current_name is current name of column in db
205
199
206
200
207 """
201 """
208
202
209 # Column attributes that can be altered
203 # Column attributes that can be altered
210 diff_keys = ('name', 'type', 'primary_key', 'nullable',
204 diff_keys = ('name', 'type', 'primary_key', 'nullable',
211 'server_onupdate', 'server_default', 'autoincrement')
205 'server_onupdate', 'server_default', 'autoincrement')
212 diffs = dict()
206 diffs = dict()
213 __visit_name__ = 'column'
207 __visit_name__ = 'column'
214
208
215 def __init__(self, *p, **kw):
209 def __init__(self, *p, **kw):
210 # 'alter_metadata' is not a public api. It exists purely
211 # as a crutch until the tests that fail when 'alter_metadata'
212 # behaviour always happens can be sorted out
216 self.alter_metadata = kw.pop("alter_metadata", False)
213 self.alter_metadata = kw.pop("alter_metadata", False)
214
217 self.meta = kw.pop("metadata", None)
215 self.meta = kw.pop("metadata", None)
218 self.engine = kw.pop("engine", None)
216 self.engine = kw.pop("engine", None)
219
217
220 # Things are initialized differently depending on how many column
218 # Things are initialized differently depending on how many column
221 # parameters are given. Figure out how many and call the appropriate
219 # parameters are given. Figure out how many and call the appropriate
222 # method.
220 # method.
223 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
221 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
224 # At least one column specified
222 # At least one column specified
225 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
223 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
226 # Two columns specified
224 # Two columns specified
227 diffs = self.compare_2_columns(*p, **kw)
225 diffs = self.compare_2_columns(*p, **kw)
228 else:
226 else:
229 # Exactly one column specified
227 # Exactly one column specified
230 diffs = self.compare_1_column(*p, **kw)
228 diffs = self.compare_1_column(*p, **kw)
231 else:
229 else:
232 # Zero columns specified
230 # Zero columns specified
233 if not len(p) or not isinstance(p[0], basestring):
231 if not len(p) or not isinstance(p[0], basestring):
234 raise ValueError("First argument must be column name")
232 raise ValueError("First argument must be column name")
235 diffs = self.compare_parameters(*p, **kw)
233 diffs = self.compare_parameters(*p, **kw)
236
234
237 self.apply_diffs(diffs)
235 self.apply_diffs(diffs)
238
236
239 def __repr__(self):
237 def __repr__(self):
240 return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
238 return '<ColumnDelta altermetadata=%r, %s>' % (
241 super(ColumnDelta, self).__repr__())
239 self.alter_metadata,
240 super(ColumnDelta, self).__repr__()
241 )
242
242
243 def __getitem__(self, key):
243 def __getitem__(self, key):
244 if key not in self.keys():
244 if key not in self.keys():
245 raise KeyError("No such diff key, available: %s" % self.diffs)
245 raise KeyError("No such diff key, available: %s" % self.diffs )
246 return getattr(self.result_column, key)
246 return getattr(self.result_column, key)
247
247
248 def __setitem__(self, key, value):
248 def __setitem__(self, key, value):
249 if key not in self.keys():
249 if key not in self.keys():
250 raise KeyError("No such diff key, available: %s" % self.diffs)
250 raise KeyError("No such diff key, available: %s" % self.diffs )
251 setattr(self.result_column, key, value)
251 setattr(self.result_column, key, value)
252
252
253 def __delitem__(self, key):
253 def __delitem__(self, key):
254 raise NotImplementedError
254 raise NotImplementedError
255
255
256 def keys(self):
256 def keys(self):
257 return self.diffs.keys()
257 return self.diffs.keys()
258
258
259 def compare_parameters(self, current_name, *p, **k):
259 def compare_parameters(self, current_name, *p, **k):
260 """Compares Column objects with reflection"""
260 """Compares Column objects with reflection"""
261 self.table = k.pop('table')
261 self.table = k.pop('table')
262 self.result_column = self._table.c.get(current_name)
262 self.result_column = self._table.c.get(current_name)
263 if len(p):
263 if len(p):
264 k = self._extract_parameters(p, k, self.result_column)
264 k = self._extract_parameters(p, k, self.result_column)
265 return k
265 return k
266
266
267 def compare_1_column(self, col, *p, **k):
267 def compare_1_column(self, col, *p, **k):
268 """Compares one Column object"""
268 """Compares one Column object"""
269 self.table = k.pop('table', None)
269 self.table = k.pop('table', None)
270 if self.table is None:
270 if self.table is None:
271 self.table = col.table
271 self.table = col.table
272 self.result_column = col
272 self.result_column = col
273 if len(p):
273 if len(p):
274 k = self._extract_parameters(p, k, self.result_column)
274 k = self._extract_parameters(p, k, self.result_column)
275 return k
275 return k
276
276
277 def compare_2_columns(self, old_col, new_col, *p, **k):
277 def compare_2_columns(self, old_col, new_col, *p, **k):
278 """Compares two Column objects"""
278 """Compares two Column objects"""
279 self.process_column(new_col)
279 self.process_column(new_col)
280 self.table = k.pop('table', None)
280 self.table = k.pop('table', None)
281 # we cannot use bool() on table in SA06
281 # we cannot use bool() on table in SA06
282 if self.table is None:
282 if self.table is None:
283 self.table = old_col.table
283 self.table = old_col.table
284 if self.table is None:
284 if self.table is None:
285 new_col.table
285 new_col.table
286 self.result_column = old_col
286 self.result_column = old_col
287
287
288 # set differences
288 # set differences
289 # leave out some stuff for later comp
289 # leave out some stuff for later comp
290 for key in (set(self.diff_keys) - set(('type',))):
290 for key in (set(self.diff_keys) - set(('type',))):
291 val = getattr(new_col, key, None)
291 val = getattr(new_col, key, None)
292 if getattr(self.result_column, key, None) != val:
292 if getattr(self.result_column, key, None) != val:
293 k.setdefault(key, val)
293 k.setdefault(key, val)
294
294
295 # inspect types
295 # inspect types
296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
297 k.setdefault('type', new_col.type)
297 k.setdefault('type', new_col.type)
298
298
299 if len(p):
299 if len(p):
300 k = self._extract_parameters(p, k, self.result_column)
300 k = self._extract_parameters(p, k, self.result_column)
301 return k
301 return k
302
302
303 def apply_diffs(self, diffs):
303 def apply_diffs(self, diffs):
304 """Populate dict and column object with new values"""
304 """Populate dict and column object with new values"""
305 self.diffs = diffs
305 self.diffs = diffs
306 for key in self.diff_keys:
306 for key in self.diff_keys:
307 if key in diffs:
307 if key in diffs:
308 setattr(self.result_column, key, diffs[key])
308 setattr(self.result_column, key, diffs[key])
309
309
310 self.process_column(self.result_column)
310 self.process_column(self.result_column)
311
311
312 # create an instance of class type if not yet
312 # create an instance of class type if not yet
313 if 'type' in diffs and callable(self.result_column.type):
313 if 'type' in diffs and callable(self.result_column.type):
314 self.result_column.type = self.result_column.type()
314 self.result_column.type = self.result_column.type()
315
315
316 # add column to the table
316 # add column to the table
317 if self.table is not None and self.alter_metadata:
317 if self.table is not None and self.alter_metadata:
318 self.result_column.add_to_table(self.table)
318 self.result_column.add_to_table(self.table)
319
319
320 def are_column_types_eq(self, old_type, new_type):
320 def are_column_types_eq(self, old_type, new_type):
321 """Compares two types to be equal"""
321 """Compares two types to be equal"""
322 ret = old_type.__class__ == new_type.__class__
322 ret = old_type.__class__ == new_type.__class__
323
323
324 # String length is a special case
324 # String length is a special case
325 if ret and isinstance(new_type, sqlalchemy.types.String):
325 if ret and isinstance(new_type, sqlalchemy.types.String):
326 ret = (getattr(old_type, 'length', None) == \
326 ret = (getattr(old_type, 'length', None) == \
327 getattr(new_type, 'length', None))
327 getattr(new_type, 'length', None))
328 return ret
328 return ret
329
329
330 def _extract_parameters(self, p, k, column):
330 def _extract_parameters(self, p, k, column):
331 """Extracts data from p and modifies diffs"""
331 """Extracts data from p and modifies diffs"""
332 p = list(p)
332 p = list(p)
333 while len(p):
333 while len(p):
334 if isinstance(p[0], basestring):
334 if isinstance(p[0], basestring):
335 k.setdefault('name', p.pop(0))
335 k.setdefault('name', p.pop(0))
336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
337 k.setdefault('type', p.pop(0))
337 k.setdefault('type', p.pop(0))
338 elif callable(p[0]):
338 elif callable(p[0]):
339 p[0] = p[0]()
339 p[0] = p[0]()
340 else:
340 else:
341 break
341 break
342
342
343 if len(p):
343 if len(p):
344 new_col = column.copy_fixed()
344 new_col = column.copy_fixed()
345 new_col._init_items(*p)
345 new_col._init_items(*p)
346 k = self.compare_2_columns(column, new_col, **k)
346 k = self.compare_2_columns(column, new_col, **k)
347 return k
347 return k
348
348
349 def process_column(self, column):
349 def process_column(self, column):
350 """Processes default values for column"""
350 """Processes default values for column"""
351 # XXX: this is a snippet from SA processing of positional parameters
351 # XXX: this is a snippet from SA processing of positional parameters
352 if not SQLA_06 and column.args:
352 if not SQLA_06 and column.args:
353 toinit = list(column.args)
353 toinit = list(column.args)
354 else:
354 else:
355 toinit = list()
355 toinit = list()
356
356
357 if column.server_default is not None:
357 if column.server_default is not None:
358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
359 toinit.append(column.server_default)
359 toinit.append(column.server_default)
360 else:
360 else:
361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
362 if column.server_onupdate is not None:
362 if column.server_onupdate is not None:
363 if isinstance(column.server_onupdate, FetchedValue):
363 if isinstance(column.server_onupdate, FetchedValue):
364 toinit.append(column.server_default)
364 toinit.append(column.server_default)
365 else:
365 else:
366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
367 for_update=True))
367 for_update=True))
368 if toinit:
368 if toinit:
369 column._init_items(*toinit)
369 column._init_items(*toinit)
370
370
371 if not SQLA_06:
371 if not SQLA_06:
372 column.args = []
372 column.args = []
373
373
374 def _get_table(self):
374 def _get_table(self):
375 return getattr(self, '_table', None)
375 return getattr(self, '_table', None)
376
376
377 def _set_table(self, table):
377 def _set_table(self, table):
378 if isinstance(table, basestring):
378 if isinstance(table, basestring):
379 if self.alter_metadata:
379 if self.alter_metadata:
380 if not self.meta:
380 if not self.meta:
381 raise ValueError("metadata must be specified for table"
381 raise ValueError("metadata must be specified for table"
382 " reflection when using alter_metadata")
382 " reflection when using alter_metadata")
383 meta = self.meta
383 meta = self.meta
384 if self.engine:
384 if self.engine:
385 meta.bind = self.engine
385 meta.bind = self.engine
386 else:
386 else:
387 if not self.engine and not self.meta:
387 if not self.engine and not self.meta:
388 raise ValueError("engine or metadata must be specified"
388 raise ValueError("engine or metadata must be specified"
389 " to reflect tables")
389 " to reflect tables")
390 if not self.engine:
390 if not self.engine:
391 self.engine = self.meta.bind
391 self.engine = self.meta.bind
392 meta = sqlalchemy.MetaData(bind=self.engine)
392 meta = sqlalchemy.MetaData(bind=self.engine)
393 self._table = sqlalchemy.Table(table, meta, autoload=True)
393 self._table = sqlalchemy.Table(table, meta, autoload=True)
394 elif isinstance(table, sqlalchemy.Table):
394 elif isinstance(table, sqlalchemy.Table):
395 self._table = table
395 self._table = table
396 if not self.alter_metadata:
396 if not self.alter_metadata:
397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
398
399 def _get_result_column(self):
398 def _get_result_column(self):
400 return getattr(self, '_result_column', None)
399 return getattr(self, '_result_column', None)
401
400
402 def _set_result_column(self, column):
401 def _set_result_column(self, column):
403 """Set Column to Table based on alter_metadata evaluation."""
402 """Set Column to Table based on alter_metadata evaluation."""
404 self.process_column(column)
403 self.process_column(column)
405 if not hasattr(self, 'current_name'):
404 if not hasattr(self, 'current_name'):
406 self.current_name = column.name
405 self.current_name = column.name
407 if self.alter_metadata:
406 if self.alter_metadata:
408 self._result_column = column
407 self._result_column = column
409 else:
408 else:
410 self._result_column = column.copy_fixed()
409 self._result_column = column.copy_fixed()
411
410
412 table = property(_get_table, _set_table)
411 table = property(_get_table, _set_table)
413 result_column = property(_get_result_column, _set_result_column)
412 result_column = property(_get_result_column, _set_result_column)
414
413
415
414
416 class ChangesetTable(object):
415 class ChangesetTable(object):
417 """Changeset extensions to SQLAlchemy tables."""
416 """Changeset extensions to SQLAlchemy tables."""
418
417
419 def create_column(self, column, *p, **kw):
418 def create_column(self, column, *p, **kw):
420 """Creates a column.
419 """Creates a column.
421
420
422 The column parameter may be a column definition or the name of
421 The column parameter may be a column definition or the name of
423 a column in this table.
422 a column in this table.
424
423
425 API to :meth:`ChangesetColumn.create`
424 API to :meth:`ChangesetColumn.create`
426
425
427 :param column: Column to be created
426 :param column: Column to be created
428 :type column: Column instance or string
427 :type column: Column instance or string
429 """
428 """
430 if not isinstance(column, sqlalchemy.Column):
429 if not isinstance(column, sqlalchemy.Column):
431 # It's a column name
430 # It's a column name
432 column = getattr(self.c, str(column))
431 column = getattr(self.c, str(column))
433 column.create(table=self, *p, **kw)
432 column.create(table=self, *p, **kw)
434
433
435 def drop_column(self, column, *p, **kw):
434 def drop_column(self, column, *p, **kw):
436 """Drop a column, given its name or definition.
435 """Drop a column, given its name or definition.
437
436
438 API to :meth:`ChangesetColumn.drop`
437 API to :meth:`ChangesetColumn.drop`
439
438
440 :param column: Column to be droped
439 :param column: Column to be droped
441 :type column: Column instance or string
440 :type column: Column instance or string
442 """
441 """
443 if not isinstance(column, sqlalchemy.Column):
442 if not isinstance(column, sqlalchemy.Column):
444 # It's a column name
443 # It's a column name
445 try:
444 try:
446 column = getattr(self.c, str(column))
445 column = getattr(self.c, str(column))
447 except AttributeError:
446 except AttributeError:
448 # That column isn't part of the table. We don't need
447 # That column isn't part of the table. We don't need
449 # its entire definition to drop the column, just its
448 # its entire definition to drop the column, just its
450 # name, so create a dummy column with the same name.
449 # name, so create a dummy column with the same name.
451 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
450 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
452 column.drop(table=self, *p, **kw)
451 column.drop(table=self, *p, **kw)
453
452
454 def rename(self, name, connection=None, **kwargs):
453 def rename(self, name, connection=None, **kwargs):
455 """Rename this table.
454 """Rename this table.
456
455
457 :param name: New name of the table.
456 :param name: New name of the table.
458 :type name: string
457 :type name: string
459 :param alter_metadata: If True, table will be removed from metadata
460 :type alter_metadata: bool
461 :param connection: reuse connection istead of creating new one.
458 :param connection: reuse connection istead of creating new one.
462 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
459 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
463 """
460 """
464 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
465 engine = self.bind
461 engine = self.bind
466 self.new_name = name
462 self.new_name = name
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
463 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
464 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469
465
470 # Fix metadata registration
466 # Fix metadata registration
471 if self.alter_metadata:
472 self.name = name
467 self.name = name
473 self.deregister()
468 self.deregister()
474 self._set_parent(self.metadata)
469 self._set_parent(self.metadata)
475
470
476 def _meta_key(self):
471 def _meta_key(self):
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
472 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478
473
479 def deregister(self):
474 def deregister(self):
480 """Remove this table from its metadata"""
475 """Remove this table from its metadata"""
481 key = self._meta_key()
476 key = self._meta_key()
482 meta = self.metadata
477 meta = self.metadata
483 if key in meta.tables:
478 if key in meta.tables:
484 del meta.tables[key]
479 del meta.tables[key]
485
480
486
481
487 class ChangesetColumn(object):
482 class ChangesetColumn(object):
488 """Changeset extensions to SQLAlchemy columns."""
483 """Changeset extensions to SQLAlchemy columns."""
489
484
490 def alter(self, *p, **k):
485 def alter(self, *p, **k):
491 """Makes a call to :func:`alter_column` for the column this
486 """Makes a call to :func:`alter_column` for the column this
492 method is called on.
487 method is called on.
493 """
488 """
494 if 'table' not in k:
489 if 'table' not in k:
495 k['table'] = self.table
490 k['table'] = self.table
496 if 'engine' not in k:
491 if 'engine' not in k:
497 k['engine'] = k['table'].bind
492 k['engine'] = k['table'].bind
498 return alter_column(self, *p, **k)
493 return alter_column(self, *p, **k)
499
494
500 def create(self, table=None, index_name=None, unique_name=None,
495 def create(self, table=None, index_name=None, unique_name=None,
501 primary_key_name=None, populate_default=True, connection=None, **kwargs):
496 primary_key_name=None, populate_default=True, connection=None, **kwargs):
502 """Create this column in the database.
497 """Create this column in the database.
503
498
504 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
499 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
505 for most databases.
500 for most databases.
506
501
507 :param table: Table instance to create on.
502 :param table: Table instance to create on.
508 :param index_name: Creates :class:`ChangesetIndex` on this column.
503 :param index_name: Creates :class:`ChangesetIndex` on this column.
509 :param unique_name: Creates :class:\
504 :param unique_name: Creates :class:\
510 `~migrate.changeset.constraint.UniqueConstraint` on this column.
505 `~migrate.changeset.constraint.UniqueConstraint` on this column.
511 :param primary_key_name: Creates :class:\
506 :param primary_key_name: Creates :class:\
512 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
507 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
513 :param alter_metadata: If True, column will be added to table object.
514 :param populate_default: If True, created column will be \
508 :param populate_default: If True, created column will be \
515 populated with defaults
509 populated with defaults
516 :param connection: reuse connection istead of creating new one.
510 :param connection: reuse connection istead of creating new one.
517 :type table: Table instance
511 :type table: Table instance
518 :type index_name: string
512 :type index_name: string
519 :type unique_name: string
513 :type unique_name: string
520 :type primary_key_name: string
514 :type primary_key_name: string
521 :type alter_metadata: bool
522 :type populate_default: bool
515 :type populate_default: bool
523 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
516 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
524
517
525 :returns: self
518 :returns: self
526 """
519 """
527 self.populate_default = populate_default
520 self.populate_default = populate_default
528 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
529 self.index_name = index_name
521 self.index_name = index_name
530 self.unique_name = unique_name
522 self.unique_name = unique_name
531 self.primary_key_name = primary_key_name
523 self.primary_key_name = primary_key_name
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
524 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 self._check_sanity_constraints(cons)
525 self._check_sanity_constraints(cons)
534
526
535 if self.alter_metadata:
536 self.add_to_table(table)
527 self.add_to_table(table)
537 engine = self.table.bind
528 engine = self.table.bind
538 visitorcallable = get_engine_visitor(engine, 'columngenerator')
529 visitorcallable = get_engine_visitor(engine, 'columngenerator')
539 engine._run_visitor(visitorcallable, self, connection, **kwargs)
530 engine._run_visitor(visitorcallable, self, connection, **kwargs)
540
531
541 # TODO: reuse existing connection
532 # TODO: reuse existing connection
542 if self.populate_default and self.default is not None:
533 if self.populate_default and self.default is not None:
543 stmt = table.update().values({self: engine._execute_default(self.default)})
534 stmt = table.update().values({self: engine._execute_default(self.default)})
544 engine.execute(stmt)
535 engine.execute(stmt)
545
536
546 return self
537 return self
547
538
548 def drop(self, table=None, connection=None, **kwargs):
539 def drop(self, table=None, connection=None, **kwargs):
549 """Drop this column from the database, leaving its table intact.
540 """Drop this column from the database, leaving its table intact.
550
541
551 ``ALTER TABLE DROP COLUMN``, for most databases.
542 ``ALTER TABLE DROP COLUMN``, for most databases.
552
543
553 :param alter_metadata: If True, column will be removed from table object.
554 :type alter_metadata: bool
555 :param connection: reuse connection istead of creating new one.
544 :param connection: reuse connection istead of creating new one.
556 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
545 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
557 """
546 """
558 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
559 if table is not None:
547 if table is not None:
560 self.table = table
548 self.table = table
561 engine = self.table.bind
549 engine = self.table.bind
562 if self.alter_metadata:
563 self.remove_from_table(self.table, unset_table=False)
564 visitorcallable = get_engine_visitor(engine, 'columndropper')
550 visitorcallable = get_engine_visitor(engine, 'columndropper')
565 engine._run_visitor(visitorcallable, self, connection, **kwargs)
551 engine._run_visitor(visitorcallable, self, connection, **kwargs)
566 if self.alter_metadata:
552 self.remove_from_table(self.table, unset_table=False)
567 self.table = None
553 self.table = None
568 return self
554 return self
569
555
570 def add_to_table(self, table):
556 def add_to_table(self, table):
571 if table is not None and self.table is None:
557 if table is not None and self.table is None:
572 self._set_parent(table)
558 self._set_parent(table)
573
559
574 def _col_name_in_constraint(self, cons, name):
560 def _col_name_in_constraint(self,cons,name):
575 return False
561 return False
576
562
577 def remove_from_table(self, table, unset_table=True):
563 def remove_from_table(self, table, unset_table=True):
578 # TODO: remove primary keys, constraints, etc
564 # TODO: remove primary keys, constraints, etc
579 if unset_table:
565 if unset_table:
580 self.table = None
566 self.table = None
581
567
582 to_drop = set()
568 to_drop = set()
583 for index in table.indexes:
569 for index in table.indexes:
584 columns = []
570 columns = []
585 for col in index.columns:
571 for col in index.columns:
586 if col.name != self.name:
572 if col.name!=self.name:
587 columns.append(col)
573 columns.append(col)
588 if columns:
574 if columns:
589 index.columns = columns
575 index.columns=columns
590 else:
576 else:
591 to_drop.add(index)
577 to_drop.add(index)
592 table.indexes = table.indexes - to_drop
578 table.indexes = table.indexes - to_drop
593
579
594 to_drop = set()
580 to_drop = set()
595 for cons in table.constraints:
581 for cons in table.constraints:
596 # TODO: deal with other types of constraint
582 # TODO: deal with other types of constraint
597 if isinstance(cons, (ForeignKeyConstraint,
583 if isinstance(cons,(ForeignKeyConstraint,
598 UniqueConstraint)):
584 UniqueConstraint)):
599 for col_name in cons.columns:
585 for col_name in cons.columns:
600 if not isinstance(col_name, basestring):
586 if not isinstance(col_name,basestring):
601 col_name = col_name.name
587 col_name = col_name.name
602 if self.name == col_name:
588 if self.name==col_name:
603 to_drop.add(cons)
589 to_drop.add(cons)
604 table.constraints = table.constraints - to_drop
590 table.constraints = table.constraints - to_drop
605
591
606 if table.c.contains_column(self):
592 if table.c.contains_column(self):
607 table.c.remove(self)
593 table.c.remove(self)
608
594
609 # TODO: this is fixed in 0.6
595 # TODO: this is fixed in 0.6
610 def copy_fixed(self, **kw):
596 def copy_fixed(self, **kw):
611 """Create a copy of this ``Column``, with all attributes."""
597 """Create a copy of this ``Column``, with all attributes."""
612 return sqlalchemy.Column(self.name, self.type, self.default,
598 return sqlalchemy.Column(self.name, self.type, self.default,
613 key=self.key,
599 key=self.key,
614 primary_key=self.primary_key,
600 primary_key=self.primary_key,
615 nullable=self.nullable,
601 nullable=self.nullable,
616 quote=self.quote,
602 quote=self.quote,
617 index=self.index,
603 index=self.index,
618 unique=self.unique,
604 unique=self.unique,
619 onupdate=self.onupdate,
605 onupdate=self.onupdate,
620 autoincrement=self.autoincrement,
606 autoincrement=self.autoincrement,
621 server_default=self.server_default,
607 server_default=self.server_default,
622 server_onupdate=self.server_onupdate,
608 server_onupdate=self.server_onupdate,
623 *[c.copy(**kw) for c in self.constraints])
609 *[c.copy(**kw) for c in self.constraints])
624
610
625 def _check_sanity_constraints(self, name):
611 def _check_sanity_constraints(self, name):
626 """Check if constraints names are correct"""
612 """Check if constraints names are correct"""
627 obj = getattr(self, name)
613 obj = getattr(self, name)
628 if (getattr(self, name[:-5]) and not obj):
614 if (getattr(self, name[:-5]) and not obj):
629 raise InvalidConstraintError("Column.create() accepts index_name,"
615 raise InvalidConstraintError("Column.create() accepts index_name,"
630 " primary_key_name and unique_name to generate constraints")
616 " primary_key_name and unique_name to generate constraints")
631 if not isinstance(obj, basestring) and obj is not None:
617 if not isinstance(obj, basestring) and obj is not None:
632 raise InvalidConstraintError(
618 raise InvalidConstraintError(
633 "%s argument for column must be constraint name" % name)
619 "%s argument for column must be constraint name" % name)
634
620
635
621
636 class ChangesetIndex(object):
622 class ChangesetIndex(object):
637 """Changeset extensions to SQLAlchemy Indexes."""
623 """Changeset extensions to SQLAlchemy Indexes."""
638
624
639 __visit_name__ = 'index'
625 __visit_name__ = 'index'
640
626
641 def rename(self, name, connection=None, **kwargs):
627 def rename(self, name, connection=None, **kwargs):
642 """Change the name of an index.
628 """Change the name of an index.
643
629
644 :param name: New name of the Index.
630 :param name: New name of the Index.
645 :type name: string
631 :type name: string
646 :param alter_metadata: If True, Index object will be altered.
647 :type alter_metadata: bool
648 :param connection: reuse connection istead of creating new one.
632 :param connection: reuse connection istead of creating new one.
649 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
633 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
650 """
634 """
651 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
652 engine = self.table.bind
635 engine = self.table.bind
653 self.new_name = name
636 self.new_name = name
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
637 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
638 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 if self.alter_metadata:
657 self.name = name
639 self.name = name
658
640
659
641
660 class ChangesetDefaultClause(object):
642 class ChangesetDefaultClause(object):
661 """Implements comparison between :class:`DefaultClause` instances"""
643 """Implements comparison between :class:`DefaultClause` instances"""
662
644
663 def __eq__(self, other):
645 def __eq__(self, other):
664 if isinstance(other, self.__class__):
646 if isinstance(other, self.__class__):
665 if self.arg == other.arg:
647 if self.arg == other.arg:
666 return True
648 return True
667
649
668 def __ne__(self, other):
650 def __ne__(self, other):
669 return not self.__eq__(other)
651 return not self.__eq__(other)
@@ -1,253 +1,253 b''
1 """
1 """
2 Code to generate a Python model from a database or differences
2 Code to generate a Python model from a database or differences
3 between a model and database.
3 between a model and database.
4
4
5 Some of this is borrowed heavily from the AutoCode project at:
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
6 http://code.google.com/p/sqlautocode/
7 """
7 """
8
8
9 import sys
9 import sys
10 import logging
10 import logging
11
11
12 import sqlalchemy
12 import sqlalchemy
13
13
14 from rhodecode.lib.dbmigrate import migrate
14 from rhodecode.lib.dbmigrate import migrate
15 from rhodecode.lib.dbmigrate.migrate import changeset
15 from rhodecode.lib.dbmigrate.migrate import changeset
16
16
17 log = logging.getLogger(__name__)
17 log = logging.getLogger(__name__)
18 HEADER = """
18 HEADER = """
19 ## File autogenerated by genmodel.py
19 ## File autogenerated by genmodel.py
20
20
21 from sqlalchemy import *
21 from sqlalchemy import *
22 meta = MetaData()
22 meta = MetaData()
23 """
23 """
24
24
25 DECLARATIVE_HEADER = """
25 DECLARATIVE_HEADER = """
26 ## File autogenerated by genmodel.py
26 ## File autogenerated by genmodel.py
27
27
28 from sqlalchemy import *
28 from sqlalchemy import *
29 from sqlalchemy.ext import declarative
29 from sqlalchemy.ext import declarative
30
30
31 Base = declarative.declarative_base()
31 Base = declarative.declarative_base()
32 """
32 """
33
33
34
34
35 class ModelGenerator(object):
35 class ModelGenerator(object):
36
36
37 def __init__(self, diff, engine, declarative=False):
37 def __init__(self, diff, engine, declarative=False):
38 self.diff = diff
38 self.diff = diff
39 self.engine = engine
39 self.engine = engine
40 self.declarative = declarative
40 self.declarative = declarative
41
41
42 def column_repr(self, col):
42 def column_repr(self, col):
43 kwarg = []
43 kwarg = []
44 if col.key != col.name:
44 if col.key != col.name:
45 kwarg.append('key')
45 kwarg.append('key')
46 if col.primary_key:
46 if col.primary_key:
47 col.primary_key = True # otherwise it dumps it as 1
47 col.primary_key = True # otherwise it dumps it as 1
48 kwarg.append('primary_key')
48 kwarg.append('primary_key')
49 if not col.nullable:
49 if not col.nullable:
50 kwarg.append('nullable')
50 kwarg.append('nullable')
51 if col.onupdate:
51 if col.onupdate:
52 kwarg.append('onupdate')
52 kwarg.append('onupdate')
53 if col.default:
53 if col.default:
54 if col.primary_key:
54 if col.primary_key:
55 # I found that PostgreSQL automatically creates a
55 # I found that PostgreSQL automatically creates a
56 # default value for the sequence, but let's not show
56 # default value for the sequence, but let's not show
57 # that.
57 # that.
58 pass
58 pass
59 else:
59 else:
60 kwarg.append('default')
60 kwarg.append('default')
61 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
61 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
62
62
63 # crs: not sure if this is good idea, but it gets rid of extra
63 # crs: not sure if this is good idea, but it gets rid of extra
64 # u''
64 # u''
65 name = col.name.encode('utf8')
65 name = col.name.encode('utf8')
66
66
67 type_ = col.type
67 type_ = col.type
68 for cls in col.type.__class__.__mro__:
68 for cls in col.type.__class__.__mro__:
69 if cls.__module__ == 'sqlalchemy.types' and \
69 if cls.__module__ == 'sqlalchemy.types' and \
70 not cls.__name__.isupper():
70 not cls.__name__.isupper():
71 if cls is not type_.__class__:
71 if cls is not type_.__class__:
72 type_ = cls()
72 type_ = cls()
73 break
73 break
74
74
75 data = {
75 data = {
76 'name': name,
76 'name': name,
77 'type': type_,
77 'type': type_,
78 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
78 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
79 'args': ks and ks or ''}
79 'args': ks and ks or ''}
80
80
81 if data['constraints']:
81 if data['constraints']:
82 if data['args']:
82 if data['args']:
83 data['args'] = ',' + data['args']
83 data['args'] = ',' + data['args']
84
84
85 if data['constraints'] or data['args']:
85 if data['constraints'] or data['args']:
86 data['maybeComma'] = ','
86 data['maybeComma'] = ','
87 else:
87 else:
88 data['maybeComma'] = ''
88 data['maybeComma'] = ''
89
89
90 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
90 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
91 commonStuff = commonStuff.strip()
91 commonStuff = commonStuff.strip()
92 data['commonStuff'] = commonStuff
92 data['commonStuff'] = commonStuff
93 if self.declarative:
93 if self.declarative:
94 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
94 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
95 else:
95 else:
96 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
96 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
97
97
98 def getTableDefn(self, table):
98 def getTableDefn(self, table):
99 out = []
99 out = []
100 tableName = table.name
100 tableName = table.name
101 if self.declarative:
101 if self.declarative:
102 out.append("class %(table)s(Base):" % {'table': tableName})
102 out.append("class %(table)s(Base):" % {'table': tableName})
103 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
103 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
104 for col in table.columns:
104 for col in table.columns:
105 out.append(" %s" % self.column_repr(col))
105 out.append(" %s" % self.column_repr(col))
106 else:
106 else:
107 out.append("%(table)s = Table('%(table)s', meta," % \
107 out.append("%(table)s = Table('%(table)s', meta," % \
108 {'table': tableName})
108 {'table': tableName})
109 for col in table.columns:
109 for col in table.columns:
110 out.append(" %s," % self.column_repr(col))
110 out.append(" %s," % self.column_repr(col))
111 out.append(")")
111 out.append(")")
112 return out
112 return out
113
113
114 def _get_tables(self, missingA=False, missingB=False, modified=False):
114 def _get_tables(self,missingA=False,missingB=False,modified=False):
115 to_process = []
115 to_process = []
116 for bool_, names, metadata in (
116 for bool_,names,metadata in (
117 (missingA, self.diff.tables_missing_from_A, self.diff.metadataB),
117 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
118 (missingB, self.diff.tables_missing_from_B, self.diff.metadataA),
118 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
119 (modified, self.diff.tables_different, self.diff.metadataA),
119 (modified,self.diff.tables_different,self.diff.metadataA),
120 ):
120 ):
121 if bool_:
121 if bool_:
122 for name in names:
122 for name in names:
123 yield metadata.tables.get(name)
123 yield metadata.tables.get(name)
124
124
125 def toPython(self):
125 def toPython(self):
126 """Assume database is current and model is empty."""
126 """Assume database is current and model is empty."""
127 out = []
127 out = []
128 if self.declarative:
128 if self.declarative:
129 out.append(DECLARATIVE_HEADER)
129 out.append(DECLARATIVE_HEADER)
130 else:
130 else:
131 out.append(HEADER)
131 out.append(HEADER)
132 out.append("")
132 out.append("")
133 for table in self._get_tables(missingA=True):
133 for table in self._get_tables(missingA=True):
134 out.extend(self.getTableDefn(table))
134 out.extend(self.getTableDefn(table))
135 out.append("")
135 out.append("")
136 return '\n'.join(out)
136 return '\n'.join(out)
137
137
138 def toUpgradeDowngradePython(self, indent=' '):
138 def toUpgradeDowngradePython(self, indent=' '):
139 ''' Assume model is most current and database is out-of-date. '''
139 ''' Assume model is most current and database is out-of-date. '''
140 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
140 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
141 'meta = MetaData()']
141 'meta = MetaData()']
142 for table in self._get_tables(
142 for table in self._get_tables(
143 missingA=True, missingB=True, modified=True
143 missingA=True,missingB=True,modified=True
144 ):
144 ):
145 decls.extend(self.getTableDefn(table))
145 decls.extend(self.getTableDefn(table))
146
146
147 upgradeCommands, downgradeCommands = [], []
147 upgradeCommands, downgradeCommands = [], []
148 for tableName in self.diff.tables_missing_from_A:
148 for tableName in self.diff.tables_missing_from_A:
149 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
149 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
150 downgradeCommands.append("%(table)s.create()" % \
150 downgradeCommands.append("%(table)s.create()" % \
151 {'table': tableName})
151 {'table': tableName})
152 for tableName in self.diff.tables_missing_from_B:
152 for tableName in self.diff.tables_missing_from_B:
153 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
153 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
154 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
154 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
155
155
156 for tableName in self.diff.tables_different:
156 for tableName in self.diff.tables_different:
157 dbTable = self.diff.metadataB.tables[tableName]
157 dbTable = self.diff.metadataB.tables[tableName]
158 missingInDatabase, missingInModel, diffDecl = \
158 missingInDatabase, missingInModel, diffDecl = \
159 self.diff.colDiffs[tableName]
159 self.diff.colDiffs[tableName]
160 for col in missingInDatabase:
160 for col in missingInDatabase:
161 upgradeCommands.append('%s.columns[%r].create()' % (
161 upgradeCommands.append('%s.columns[%r].create()' % (
162 modelTable, col.name))
162 modelTable, col.name))
163 downgradeCommands.append('%s.columns[%r].drop()' % (
163 downgradeCommands.append('%s.columns[%r].drop()' % (
164 modelTable, col.name))
164 modelTable, col.name))
165 for col in missingInModel:
165 for col in missingInModel:
166 upgradeCommands.append('%s.columns[%r].drop()' % (
166 upgradeCommands.append('%s.columns[%r].drop()' % (
167 modelTable, col.name))
167 modelTable, col.name))
168 downgradeCommands.append('%s.columns[%r].create()' % (
168 downgradeCommands.append('%s.columns[%r].create()' % (
169 modelTable, col.name))
169 modelTable, col.name))
170 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
170 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
171 upgradeCommands.append(
171 upgradeCommands.append(
172 'assert False, "Can\'t alter columns: %s:%s=>%s"',
172 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
173 modelTable, modelCol.name, databaseCol.name)
173 modelTable, modelCol.name, databaseCol.name))
174 downgradeCommands.append(
174 downgradeCommands.append(
175 'assert False, "Can\'t alter columns: %s:%s=>%s"',
175 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
176 modelTable, modelCol.name, databaseCol.name)
176 modelTable, modelCol.name, databaseCol.name))
177 pre_command = ' meta.bind = migrate_engine'
177 pre_command = ' meta.bind = migrate_engine'
178
178
179 return (
179 return (
180 '\n'.join(decls),
180 '\n'.join(decls),
181 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
181 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
183
183
184 def _db_can_handle_this_change(self, td):
184 def _db_can_handle_this_change(self,td):
185 if (td.columns_missing_from_B
185 if (td.columns_missing_from_B
186 and not td.columns_missing_from_A
186 and not td.columns_missing_from_A
187 and not td.columns_different):
187 and not td.columns_different):
188 # Even sqlite can handle this.
188 # Even sqlite can handle this.
189 return True
189 return True
190 else:
190 else:
191 return not self.engine.url.drivername.startswith('sqlite')
191 return not self.engine.url.drivername.startswith('sqlite')
192
192
193 def applyModel(self):
193 def applyModel(self):
194 """Apply model to current database."""
194 """Apply model to current database."""
195
195
196 meta = sqlalchemy.MetaData(self.engine)
196 meta = sqlalchemy.MetaData(self.engine)
197
197
198 for table in self._get_tables(missingA=True):
198 for table in self._get_tables(missingA=True):
199 table = table.tometadata(meta)
199 table = table.tometadata(meta)
200 table.drop()
200 table.drop()
201 for table in self._get_tables(missingB=True):
201 for table in self._get_tables(missingB=True):
202 table = table.tometadata(meta)
202 table = table.tometadata(meta)
203 table.create()
203 table.create()
204 for modelTable in self._get_tables(modified=True):
204 for modelTable in self._get_tables(modified=True):
205 tableName = modelTable.name
205 tableName = modelTable.name
206 modelTable = modelTable.tometadata(meta)
206 modelTable = modelTable.tometadata(meta)
207 dbTable = self.diff.metadataB.tables[tableName]
207 dbTable = self.diff.metadataB.tables[tableName]
208
208
209 td = self.diff.tables_different[tableName]
209 td = self.diff.tables_different[tableName]
210
210
211 if self._db_can_handle_this_change(td):
211 if self._db_can_handle_this_change(td):
212
212
213 for col in td.columns_missing_from_B:
213 for col in td.columns_missing_from_B:
214 modelTable.columns[col].create()
214 modelTable.columns[col].create()
215 for col in td.columns_missing_from_A:
215 for col in td.columns_missing_from_A:
216 dbTable.columns[col].drop()
216 dbTable.columns[col].drop()
217 # XXX handle column changes here.
217 # XXX handle column changes here.
218 else:
218 else:
219 # Sqlite doesn't support drop column, so you have to
219 # Sqlite doesn't support drop column, so you have to
220 # do more: create temp table, copy data to it, drop
220 # do more: create temp table, copy data to it, drop
221 # old table, create new table, copy data back.
221 # old table, create new table, copy data back.
222 #
222 #
223 # I wonder if this is guaranteed to be unique?
223 # I wonder if this is guaranteed to be unique?
224 tempName = '_temp_%s' % modelTable.name
224 tempName = '_temp_%s' % modelTable.name
225
225
226 def getCopyStatement():
226 def getCopyStatement():
227 preparer = self.engine.dialect.preparer
227 preparer = self.engine.dialect.preparer
228 commonCols = []
228 commonCols = []
229 for modelCol in modelTable.columns:
229 for modelCol in modelTable.columns:
230 if modelCol.name in dbTable.columns:
230 if modelCol.name in dbTable.columns:
231 commonCols.append(modelCol.name)
231 commonCols.append(modelCol.name)
232 commonColsStr = ', '.join(commonCols)
232 commonColsStr = ', '.join(commonCols)
233 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
233 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
234 (tableName, commonColsStr, commonColsStr, tempName)
234 (tableName, commonColsStr, commonColsStr, tempName)
235
235
236 # Move the data in one transaction, so that we don't
236 # Move the data in one transaction, so that we don't
237 # leave the database in a nasty state.
237 # leave the database in a nasty state.
238 connection = self.engine.connect()
238 connection = self.engine.connect()
239 trans = connection.begin()
239 trans = connection.begin()
240 try:
240 try:
241 connection.execute(
241 connection.execute(
242 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
242 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
243 (tempName, modelTable.name))
243 (tempName, modelTable.name))
244 # make sure the drop takes place inside our
244 # make sure the drop takes place inside our
245 # transaction with the bind parameter
245 # transaction with the bind parameter
246 modelTable.drop(bind=connection)
246 modelTable.drop(bind=connection)
247 modelTable.create(bind=connection)
247 modelTable.create(bind=connection)
248 connection.execute(getCopyStatement())
248 connection.execute(getCopyStatement())
249 connection.execute('DROP TABLE %s' % tempName)
249 connection.execute('DROP TABLE %s' % tempName)
250 trans.commit()
250 trans.commit()
251 except:
251 except:
252 trans.rollback()
252 trans.rollback()
253 raise
253 raise
@@ -1,159 +1,160 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 import shutil
4 import shutil
5 import warnings
5 import warnings
6 import logging
6 import logging
7 import inspect
7 from StringIO import StringIO
8 from StringIO import StringIO
8
9
9 from rhodecode.lib.dbmigrate import migrate
10 from rhodecode.lib.dbmigrate import migrate
10 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
11 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
11 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
12 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
12 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
13 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
13 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
14 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
15 from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
15 from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
16 from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
16
17
17 log = logging.getLogger(__name__)
18 log = logging.getLogger(__name__)
18 __all__ = ['PythonScript']
19 __all__ = ['PythonScript']
19
20
20
21
21 class PythonScript(base.BaseScript):
22 class PythonScript(base.BaseScript):
22 """Base for Python scripts"""
23 """Base for Python scripts"""
23
24
24 @classmethod
25 @classmethod
25 def create(cls, path, **opts):
26 def create(cls, path, **opts):
26 """Create an empty migration script at specified path
27 """Create an empty migration script at specified path
27
28
28 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
29 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
29 cls.require_notfound(path)
30 cls.require_notfound(path)
30
31
31 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
32 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
32 shutil.copy(src, path)
33 shutil.copy(src, path)
33
34
34 return cls(path)
35 return cls(path)
35
36
36 @classmethod
37 @classmethod
37 def make_update_script_for_model(cls, engine, oldmodel,
38 def make_update_script_for_model(cls, engine, oldmodel,
38 model, repository, **opts):
39 model, repository, **opts):
39 """Create a migration script based on difference between two SA models.
40 """Create a migration script based on difference between two SA models.
40
41
41 :param repository: path to migrate repository
42 :param repository: path to migrate repository
42 :param oldmodel: dotted.module.name:SAClass or SAClass object
43 :param oldmodel: dotted.module.name:SAClass or SAClass object
43 :param model: dotted.module.name:SAClass or SAClass object
44 :param model: dotted.module.name:SAClass or SAClass object
44 :param engine: SQLAlchemy engine
45 :param engine: SQLAlchemy engine
45 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
46 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
46 :type oldmodel: string or Class
47 :type oldmodel: string or Class
47 :type model: string or Class
48 :type model: string or Class
48 :type engine: Engine instance
49 :type engine: Engine instance
49 :returns: Upgrade / Downgrade script
50 :returns: Upgrade / Downgrade script
50 :rtype: string
51 :rtype: string
51 """
52 """
52
53
53 if isinstance(repository, basestring):
54 if isinstance(repository, basestring):
54 # oh dear, an import cycle!
55 # oh dear, an import cycle!
55 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
56 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
56 repository = Repository(repository)
57 repository = Repository(repository)
57
58
58 oldmodel = load_model(oldmodel)
59 oldmodel = load_model(oldmodel)
59 model = load_model(model)
60 model = load_model(model)
60
61
61 # Compute differences.
62 # Compute differences.
62 diff = schemadiff.getDiffOfModelAgainstModel(
63 diff = schemadiff.getDiffOfModelAgainstModel(
63 oldmodel,
64 oldmodel,
64 model,
65 model,
65 excludeTables=[repository.version_table])
66 excludeTables=[repository.version_table])
66 # TODO: diff can be False (there is no difference?)
67 # TODO: diff can be False (there is no difference?)
67 decls, upgradeCommands, downgradeCommands = \
68 decls, upgradeCommands, downgradeCommands = \
68 genmodel.ModelGenerator(diff, engine).toUpgradeDowngradePython()
69 genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
69
70
70 # Store differences into file.
71 # Store differences into file.
71 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
72 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
72 f = open(src)
73 f = open(src)
73 contents = f.read()
74 contents = f.read()
74 f.close()
75 f.close()
75
76
76 # generate source
77 # generate source
77 search = 'def upgrade(migrate_engine):'
78 search = 'def upgrade(migrate_engine):'
78 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
79 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
79 if upgradeCommands:
80 if upgradeCommands:
80 contents = contents.replace(' pass', upgradeCommands, 1)
81 contents = contents.replace(' pass', upgradeCommands, 1)
81 if downgradeCommands:
82 if downgradeCommands:
82 contents = contents.replace(' pass', downgradeCommands, 1)
83 contents = contents.replace(' pass', downgradeCommands, 1)
83 return contents
84 return contents
84
85
85 @classmethod
86 @classmethod
86 def verify_module(cls, path):
87 def verify_module(cls, path):
87 """Ensure path is a valid script
88 """Ensure path is a valid script
88
89
89 :param path: Script location
90 :param path: Script location
90 :type path: string
91 :type path: string
91 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
92 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
92 :returns: Python module
93 :returns: Python module
93 """
94 """
94 # Try to import and get the upgrade() func
95 # Try to import and get the upgrade() func
95 module = import_path(path)
96 module = import_path(path)
96 try:
97 try:
97 assert callable(module.upgrade)
98 assert callable(module.upgrade)
98 except Exception, e:
99 except Exception, e:
99 raise InvalidScriptError(path + ': %s' % str(e))
100 raise InvalidScriptError(path + ': %s' % str(e))
100 return module
101 return module
101
102
102 def preview_sql(self, url, step, **args):
103 def preview_sql(self, url, step, **args):
103 """Mocks SQLAlchemy Engine to store all executed calls in a string
104 """Mocks SQLAlchemy Engine to store all executed calls in a string
104 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
105 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
105
106
106 :returns: SQL file
107 :returns: SQL file
107 """
108 """
108 buf = StringIO()
109 buf = StringIO()
109 args['engine_arg_strategy'] = 'mock'
110 args['engine_arg_strategy'] = 'mock'
110 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
111 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
111
112
112 @with_engine
113 @with_engine
113 def go(url, step, **kw):
114 def go(url, step, **kw):
114 engine = kw.pop('engine')
115 engine = kw.pop('engine')
115 self.run(engine, step)
116 self.run(engine, step)
116 return buf.getvalue()
117 return buf.getvalue()
117
118
118 return go(url, step, **args)
119 return go(url, step, **args)
119
120
120 def run(self, engine, step):
121 def run(self, engine, step):
121 """Core method of Script file.
122 """Core method of Script file.
122 Exectues :func:`update` or :func:`downgrade` functions
123 Exectues :func:`update` or :func:`downgrade` functions
123
124
124 :param engine: SQLAlchemy Engine
125 :param engine: SQLAlchemy Engine
125 :param step: Operation to run
126 :param step: Operation to run
126 :type engine: string
127 :type engine: string
127 :type step: int
128 :type step: int
128 """
129 """
129 if step > 0:
130 if step > 0:
130 op = 'upgrade'
131 op = 'upgrade'
131 elif step < 0:
132 elif step < 0:
132 op = 'downgrade'
133 op = 'downgrade'
133 else:
134 else:
134 raise ScriptError("%d is not a valid step" % step)
135 raise ScriptError("%d is not a valid step" % step)
135
136
136 funcname = base.operations[op]
137 funcname = base.operations[op]
137 script_func = self._func(funcname)
138 script_func = self._func(funcname)
138
139
139 try:
140 # check for old way of using engine
141 if not inspect.getargspec(script_func)[0]:
142 raise TypeError("upgrade/downgrade functions must accept engine"
143 " parameter (since version 0.5.4)")
144
140 script_func(engine)
145 script_func(engine)
141 except TypeError:
142 warnings.warn("upgrade/downgrade functions must accept engine"
143 " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
144 raise
145
146
146 @property
147 @property
147 def module(self):
148 def module(self):
148 """Calls :meth:`migrate.versioning.script.py.verify_module`
149 """Calls :meth:`migrate.versioning.script.py.verify_module`
149 and returns it.
150 and returns it.
150 """
151 """
151 if not hasattr(self, '_module'):
152 if not hasattr(self, '_module'):
152 self._module = self.verify_module(self.path)
153 self._module = self.verify_module(self.path)
153 return self._module
154 return self._module
154
155
155 def _func(self, funcname):
156 def _func(self, funcname):
156 if not hasattr(self.module, funcname):
157 if not hasattr(self.module, funcname):
157 msg = "Function '%s' is not defined in this script"
158 msg = "Function '%s' is not defined in this script"
158 raise ScriptError(msg % funcname)
159 raise ScriptError(msg % funcname)
159 return getattr(self.module, funcname)
160 return getattr(self.module, funcname)
@@ -1,48 +1,49 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 import logging
3 import logging
4 import shutil
4 import shutil
5
5
6 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
6 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
7 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
7 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
8
8
9
9
10 log = logging.getLogger(__name__)
10 log = logging.getLogger(__name__)
11
11
12 class SqlScript(base.BaseScript):
12 class SqlScript(base.BaseScript):
13 """A file containing plain SQL statements."""
13 """A file containing plain SQL statements."""
14
14
15 @classmethod
15 @classmethod
16 def create(cls, path, **opts):
16 def create(cls, path, **opts):
17 """Create an empty migration script at specified path
17 """Create an empty migration script at specified path
18
18
19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
20 cls.require_notfound(path)
20 cls.require_notfound(path)
21
21 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
22 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
22 shutil.copy(src, path)
23 shutil.copy(src, path)
23 return cls(path)
24 return cls(path)
24
25
25 # TODO: why is step parameter even here?
26 # TODO: why is step parameter even here?
26 def run(self, engine, step=None, executemany=True):
27 def run(self, engine, step=None, executemany=True):
27 """Runs SQL script through raw dbapi execute call"""
28 """Runs SQL script through raw dbapi execute call"""
28 text = self.source()
29 text = self.source()
29 # Don't rely on SA's autocommit here
30 # Don't rely on SA's autocommit here
30 # (SA uses .startswith to check if a commit is needed. What if script
31 # (SA uses .startswith to check if a commit is needed. What if script
31 # starts with a comment?)
32 # starts with a comment?)
32 conn = engine.connect()
33 conn = engine.connect()
33 try:
34 try:
34 trans = conn.begin()
35 trans = conn.begin()
35 try:
36 try:
36 # HACK: SQLite doesn't allow multiple statements through
37 # HACK: SQLite doesn't allow multiple statements through
37 # its execute() method, but it provides executescript() instead
38 # its execute() method, but it provides executescript() instead
38 dbapi = conn.engine.raw_connection()
39 dbapi = conn.engine.raw_connection()
39 if executemany and getattr(dbapi, 'executescript', None):
40 if executemany and getattr(dbapi, 'executescript', None):
40 dbapi.executescript(text)
41 dbapi.executescript(text)
41 else:
42 else:
42 conn.execute(text)
43 conn.execute(text)
43 trans.commit()
44 trans.commit()
44 except:
45 except:
45 trans.rollback()
46 trans.rollback()
46 raise
47 raise
47 finally:
48 finally:
48 conn.close()
49 conn.close()
@@ -1,215 +1,214 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 """The migrate command-line tool."""
4 """The migrate command-line tool."""
5
5
6 import sys
6 import sys
7 import inspect
7 import inspect
8 import logging
8 import logging
9 from optparse import OptionParser, BadOptionError
9 from optparse import OptionParser, BadOptionError
10
10
11 from rhodecode.lib.dbmigrate.migrate import exceptions
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
15
15
16
16
17 alias = dict(
17 alias = dict(
18 s=api.script,
18 s=api.script,
19 vc=api.version_control,
19 vc=api.version_control,
20 dbv=api.db_version,
20 dbv=api.db_version,
21 v=api.version,
21 v=api.version,
22 )
22 )
23
23
24 def alias_setup():
24 def alias_setup():
25 global alias
25 global alias
26 for key, val in alias.iteritems():
26 for key, val in alias.iteritems():
27 setattr(api, key, val)
27 setattr(api, key, val)
28 alias_setup()
28 alias_setup()
29
29
30
30
31 class PassiveOptionParser(OptionParser):
31 class PassiveOptionParser(OptionParser):
32
32
33 def _process_args(self, largs, rargs, values):
33 def _process_args(self, largs, rargs, values):
34 """little hack to support all --some_option=value parameters"""
34 """little hack to support all --some_option=value parameters"""
35
35
36 while rargs:
36 while rargs:
37 arg = rargs[0]
37 arg = rargs[0]
38 if arg == "--":
38 if arg == "--":
39 del rargs[0]
39 del rargs[0]
40 return
40 return
41 elif arg[0:2] == "--":
41 elif arg[0:2] == "--":
42 # if parser does not know about the option
42 # if parser does not know about the option
43 # pass it along (make it anonymous)
43 # pass it along (make it anonymous)
44 try:
44 try:
45 opt = arg.split('=', 1)[0]
45 opt = arg.split('=', 1)[0]
46 self._match_long_opt(opt)
46 self._match_long_opt(opt)
47 except BadOptionError:
47 except BadOptionError:
48 largs.append(arg)
48 largs.append(arg)
49 del rargs[0]
49 del rargs[0]
50 else:
50 else:
51 self._process_long_opt(rargs, values)
51 self._process_long_opt(rargs, values)
52 elif arg[:1] == "-" and len(arg) > 1:
52 elif arg[:1] == "-" and len(arg) > 1:
53 self._process_short_opts(rargs, values)
53 self._process_short_opts(rargs, values)
54 elif self.allow_interspersed_args:
54 elif self.allow_interspersed_args:
55 largs.append(arg)
55 largs.append(arg)
56 del rargs[0]
56 del rargs[0]
57
57
58 def main(argv=None, **kwargs):
58 def main(argv=None, **kwargs):
59 """Shell interface to :mod:`migrate.versioning.api`.
59 """Shell interface to :mod:`migrate.versioning.api`.
60
60
61 kwargs are default options that can be overriden with passing
61 kwargs are default options that can be overriden with passing
62 --some_option as command line option
62 --some_option as command line option
63
63
64 :param disable_logging: Let migrate configure logging
64 :param disable_logging: Let migrate configure logging
65 :type disable_logging: bool
65 :type disable_logging: bool
66 """
66 """
67 if argv is not None:
67 if argv is not None:
68 argv = argv
68 argv = argv
69 else:
69 else:
70 argv = list(sys.argv[1:])
70 argv = list(sys.argv[1:])
71 commands = list(api.__all__)
71 commands = list(api.__all__)
72 commands.sort()
72 commands.sort()
73
73
74 usage = """%%prog COMMAND ...
74 usage = """%%prog COMMAND ...
75
75
76 Available commands:
76 Available commands:
77 %s
77 %s
78
78
79 Enter "%%prog help COMMAND" for information on a particular command.
79 Enter "%%prog help COMMAND" for information on a particular command.
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
81 api.command_desc.get(command)) for command in commands])
82
81
83 parser = PassiveOptionParser(usage=usage)
82 parser = PassiveOptionParser(usage=usage)
84 parser.add_option("-d", "--debug",
83 parser.add_option("-d", "--debug",
85 action="store_true",
84 action="store_true",
86 dest="debug",
85 dest="debug",
87 default=False,
86 default=False,
88 help="Shortcut to turn on DEBUG mode for logging")
87 help="Shortcut to turn on DEBUG mode for logging")
89 parser.add_option("-q", "--disable_logging",
88 parser.add_option("-q", "--disable_logging",
90 action="store_true",
89 action="store_true",
91 dest="disable_logging",
90 dest="disable_logging",
92 default=False,
91 default=False,
93 help="Use this option to disable logging configuration")
92 help="Use this option to disable logging configuration")
94 help_commands = ['help', '-h', '--help']
93 help_commands = ['help', '-h', '--help']
95 HELP = False
94 HELP = False
96
95
97 try:
96 try:
98 command = argv.pop(0)
97 command = argv.pop(0)
99 if command in help_commands:
98 if command in help_commands:
100 HELP = True
99 HELP = True
101 command = argv.pop(0)
100 command = argv.pop(0)
102 except IndexError:
101 except IndexError:
103 parser.print_help()
102 parser.print_help()
104 return
103 return
105
104
106 command_func = getattr(api, command, None)
105 command_func = getattr(api, command, None)
107 if command_func is None or command.startswith('_'):
106 if command_func is None or command.startswith('_'):
108 parser.error("Invalid command %s" % command)
107 parser.error("Invalid command %s" % command)
109
108
110 parser.set_usage(inspect.getdoc(command_func))
109 parser.set_usage(inspect.getdoc(command_func))
111 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
110 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
112 for arg in f_args:
111 for arg in f_args:
113 parser.add_option(
112 parser.add_option(
114 "--%s" % arg,
113 "--%s" % arg,
115 dest=arg,
114 dest=arg,
116 action='store',
115 action='store',
117 type="string")
116 type="string")
118
117
119 # display help of the current command
118 # display help of the current command
120 if HELP:
119 if HELP:
121 parser.print_help()
120 parser.print_help()
122 return
121 return
123
122
124 options, args = parser.parse_args(argv)
123 options, args = parser.parse_args(argv)
125
124
126 # override kwargs with anonymous parameters
125 # override kwargs with anonymous parameters
127 override_kwargs = dict()
126 override_kwargs = dict()
128 for arg in list(args):
127 for arg in list(args):
129 if arg.startswith('--'):
128 if arg.startswith('--'):
130 args.remove(arg)
129 args.remove(arg)
131 if '=' in arg:
130 if '=' in arg:
132 opt, value = arg[2:].split('=', 1)
131 opt, value = arg[2:].split('=', 1)
133 else:
132 else:
134 opt = arg[2:]
133 opt = arg[2:]
135 value = True
134 value = True
136 override_kwargs[opt] = value
135 override_kwargs[opt] = value
137
136
138 # override kwargs with options if user is overwriting
137 # override kwargs with options if user is overwriting
139 for key, value in options.__dict__.iteritems():
138 for key, value in options.__dict__.iteritems():
140 if value is not None:
139 if value is not None:
141 override_kwargs[key] = value
140 override_kwargs[key] = value
142
141
143 # arguments that function accepts without passed kwargs
142 # arguments that function accepts without passed kwargs
144 f_required = list(f_args)
143 f_required = list(f_args)
145 candidates = dict(kwargs)
144 candidates = dict(kwargs)
146 candidates.update(override_kwargs)
145 candidates.update(override_kwargs)
147 for key, value in candidates.iteritems():
146 for key, value in candidates.iteritems():
148 if key in f_args:
147 if key in f_args:
149 f_required.remove(key)
148 f_required.remove(key)
150
149
151 # map function arguments to parsed arguments
150 # map function arguments to parsed arguments
152 for arg in args:
151 for arg in args:
153 try:
152 try:
154 kw = f_required.pop(0)
153 kw = f_required.pop(0)
155 except IndexError:
154 except IndexError:
156 parser.error("Too many arguments for command %s: %s" % (command,
155 parser.error("Too many arguments for command %s: %s" % (command,
157 arg))
156 arg))
158 kwargs[kw] = arg
157 kwargs[kw] = arg
159
158
160 # apply overrides
159 # apply overrides
161 kwargs.update(override_kwargs)
160 kwargs.update(override_kwargs)
162
161
163 # configure options
162 # configure options
164 for key, value in options.__dict__.iteritems():
163 for key, value in options.__dict__.iteritems():
165 kwargs.setdefault(key, value)
164 kwargs.setdefault(key, value)
166
165
167 # configure logging
166 # configure logging
168 if not asbool(kwargs.pop('disable_logging', False)):
167 if not asbool(kwargs.pop('disable_logging', False)):
169 # filter to log =< INFO into stdout and rest to stderr
168 # filter to log =< INFO into stdout and rest to stderr
170 class SingleLevelFilter(logging.Filter):
169 class SingleLevelFilter(logging.Filter):
171 def __init__(self, min=None, max=None):
170 def __init__(self, min=None, max=None):
172 self.min = min or 0
171 self.min = min or 0
173 self.max = max or 100
172 self.max = max or 100
174
173
175 def filter(self, record):
174 def filter(self, record):
176 return self.min <= record.levelno <= self.max
175 return self.min <= record.levelno <= self.max
177
176
178 logger = logging.getLogger()
177 logger = logging.getLogger()
179 h1 = logging.StreamHandler(sys.stdout)
178 h1 = logging.StreamHandler(sys.stdout)
180 f1 = SingleLevelFilter(max=logging.INFO)
179 f1 = SingleLevelFilter(max=logging.INFO)
181 h1.addFilter(f1)
180 h1.addFilter(f1)
182 h2 = logging.StreamHandler(sys.stderr)
181 h2 = logging.StreamHandler(sys.stderr)
183 f2 = SingleLevelFilter(min=logging.WARN)
182 f2 = SingleLevelFilter(min=logging.WARN)
184 h2.addFilter(f2)
183 h2.addFilter(f2)
185 logger.addHandler(h1)
184 logger.addHandler(h1)
186 logger.addHandler(h2)
185 logger.addHandler(h2)
187
186
188 if options.debug:
187 if options.debug:
189 logger.setLevel(logging.DEBUG)
188 logger.setLevel(logging.DEBUG)
190 else:
189 else:
191 logger.setLevel(logging.INFO)
190 logger.setLevel(logging.INFO)
192
191
193 log = logging.getLogger(__name__)
192 log = logging.getLogger(__name__)
194
193
195 # check if all args are given
194 # check if all args are given
196 try:
195 try:
197 num_defaults = len(f_defaults)
196 num_defaults = len(f_defaults)
198 except TypeError:
197 except TypeError:
199 num_defaults = 0
198 num_defaults = 0
200 f_args_default = f_args[len(f_args) - num_defaults:]
199 f_args_default = f_args[len(f_args) - num_defaults:]
201 required = list(set(f_required) - set(f_args_default))
200 required = list(set(f_required) - set(f_args_default))
202 if required:
201 if required:
203 parser.error("Not enough arguments for command %s: %s not specified" \
202 parser.error("Not enough arguments for command %s: %s not specified" \
204 % (command, ', '.join(required)))
203 % (command, ', '.join(required)))
205
204
206 # handle command
205 # handle command
207 try:
206 try:
208 ret = command_func(**kwargs)
207 ret = command_func(**kwargs)
209 if ret is not None:
208 if ret is not None:
210 log.info(ret)
209 log.info(ret)
211 except (exceptions.UsageError, exceptions.KnownError), e:
210 except (exceptions.UsageError, exceptions.KnownError), e:
212 parser.error(e.args[0])
211 parser.error(e.args[0])
213
212
214 if __name__ == "__main__":
213 if __name__ == "__main__":
215 main()
214 main()
General Comments 0
You need to be logged in to leave comments. Login now