##// END OF EJS Templates
dbmigrate: 2to3 pass with fixes
super-admin -
r4988:ff25f201 default
parent child Browse files
Show More
@@ -1,314 +1,314 b''
1 1 """
2 2 Extensions to SQLAlchemy for altering existing tables.
3 3
4 4 At the moment, this isn't so much based off of ANSI as much as
5 5 things that just happen to work with multiple databases.
6 6 """
7 from io import StringIO
7 import io
8 8
9 9 import sqlalchemy as sa
10 10 from sqlalchemy.schema import SchemaVisitor
11 11 from sqlalchemy.engine.default import DefaultDialect
12 12 from sqlalchemy.sql import ClauseElement
13 13 from sqlalchemy.schema import (ForeignKeyConstraint,
14 14 PrimaryKeyConstraint,
15 15 CheckConstraint,
16 16 UniqueConstraint,
17 17 Index)
18 18
19 19 import sqlalchemy.sql.compiler
20 20 from rhodecode.lib.dbmigrate.migrate import exceptions
21 21 from rhodecode.lib.dbmigrate.migrate.changeset import constraint
22 22 from rhodecode.lib.dbmigrate.migrate.changeset import util
23 23
24 24 from sqlalchemy.schema import AddConstraint, DropConstraint
25 25 from sqlalchemy.sql.compiler import DDLCompiler
26 26 SchemaGenerator = SchemaDropper = DDLCompiler
27 27
28 28
29 29 class AlterTableVisitor(SchemaVisitor):
30 30 """Common operations for ``ALTER TABLE`` statements."""
31 31
32 32 # engine.Compiler looks for .statement
33 33 # when it spawns off a new compiler
34 34 statement = ClauseElement()
35 35
36 36 def append(self, s):
37 37 """Append content to the SchemaIterator's query buffer."""
38 38
39 39 self.buffer.write(s)
40 40
41 41 def execute(self):
42 42 """Execute the contents of the SchemaIterator's buffer."""
43 43 try:
44 44 return self.connection.execute(self.buffer.getvalue())
45 45 finally:
46 46 self.buffer.seek(0)
47 47 self.buffer.truncate()
48 48
49 49 def __init__(self, dialect, connection, **kw):
50 50 self.connection = connection
51 self.buffer = StringIO.StringIO()
51 self.buffer = io.StringIO()
52 52 self.preparer = dialect.identifier_preparer
53 53 self.dialect = dialect
54 54
55 55 def traverse_single(self, elem):
56 56 ret = super(AlterTableVisitor, self).traverse_single(elem)
57 57 if ret:
58 58 # adapt to 0.6 which uses a string-returning
59 59 # object
60 60 self.append(" %s" % ret)
61 61
62 62 def _to_table(self, param):
63 63 """Returns the table object for the given param object."""
64 64 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
65 65 ret = param.table
66 66 else:
67 67 ret = param
68 68 return ret
69 69
70 70 def start_alter_table(self, param):
71 71 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
72 72
73 73 Use the param object to determine the table name and use it
74 74 for building the SQL statement.
75 75
76 76 :param param: object to determine the table from
77 77 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
78 78 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
79 79 or string (table name)
80 80 """
81 81 table = self._to_table(param)
82 82 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
83 83 return table
84 84
85 85
86 86 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
87 87 """Extends ansisql generator for column creation (alter table add col)"""
88 88
89 89 def visit_column(self, column):
90 90 """Create a column (table already exists).
91 91
92 92 :param column: column object
93 93 :type column: :class:`sqlalchemy.Column` instance
94 94 """
95 95 if column.default is not None:
96 96 self.traverse_single(column.default)
97 97
98 98 table = self.start_alter_table(column)
99 99 self.append("ADD ")
100 100 self.append(self.get_column_specification(column))
101 101
102 102 for cons in column.constraints:
103 103 self.traverse_single(cons)
104 104 self.execute()
105 105
106 106 # ALTER TABLE STATEMENTS
107 107
108 108 # add indexes and unique constraints
109 109 if column.index_name:
110 110 Index(column.index_name,column).create()
111 111 elif column.unique_name:
112 112 constraint.UniqueConstraint(column,
113 113 name=column.unique_name).create()
114 114
115 115 # SA bounds FK constraints to table, add manually
116 116 for fk in column.foreign_keys:
117 117 self.add_foreignkey(fk.constraint)
118 118
119 119 # add primary key constraint if needed
120 120 if column.primary_key_name:
121 121 cons = constraint.PrimaryKeyConstraint(column,
122 122 name=column.primary_key_name)
123 123 cons.create()
124 124
125 125 def add_foreignkey(self, fk):
126 126 self.connection.execute(AddConstraint(fk))
127 127
128 128 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
129 129 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
130 130 DROP COLUMN``).
131 131 """
132 132
133 133 def visit_column(self, column):
134 134 """Drop a column from its table.
135 135
136 136 :param column: the column object
137 137 :type column: :class:`sqlalchemy.Column`
138 138 """
139 139 table = self.start_alter_table(column)
140 140 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
141 141 self.execute()
142 142
143 143
144 144 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
145 145 """Manages changes to existing schema elements.
146 146
147 147 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
148 148 is in SchemaGenerator.
149 149
150 150 All items may be renamed. Columns can also have many of their properties -
151 151 type, for example - changed.
152 152
153 153 Each function is passed a tuple, containing (object, name); where
154 154 object is a type of object you'd expect for that function
155 155 (ie. table for visit_table) and name is the object's new
156 156 name. NONE means the name is unchanged.
157 157 """
158 158
159 159 def visit_table(self, table):
160 160 """Rename a table. Other ops aren't supported."""
161 161 self.start_alter_table(table)
162 162 q = util.safe_quote(table)
163 163 self.append("RENAME TO %s" % self.preparer.quote(table.new_name, q))
164 164 self.execute()
165 165
166 166 def visit_index(self, index):
167 167 """Rename an index"""
168 168 if hasattr(self, '_validate_identifier'):
169 169 # SA <= 0.6.3
170 170 self.append("ALTER INDEX %s RENAME TO %s" % (
171 171 self.preparer.quote(
172 172 self._validate_identifier(
173 173 index.name, True), index.quote),
174 174 self.preparer.quote(
175 175 self._validate_identifier(
176 176 index.new_name, True), index.quote)))
177 177 elif hasattr(self, '_index_identifier'):
178 178 # SA >= 0.6.5, < 0.8
179 179 self.append("ALTER INDEX %s RENAME TO %s" % (
180 180 self.preparer.quote(
181 181 self._index_identifier(
182 182 index.name), index.quote),
183 183 self.preparer.quote(
184 184 self._index_identifier(
185 185 index.new_name), index.quote)))
186 186 else:
187 187 # SA >= 0.8
188 188 class NewName(object):
189 189 """Map obj.name -> obj.new_name"""
190 190 def __init__(self, index):
191 191 self.name = index.new_name
192 192 self._obj = index
193 193
194 194 def __getattr__(self, attr):
195 195 if attr == 'name':
196 196 return getattr(self, attr)
197 197 return getattr(self._obj, attr)
198 198
199 199 self.append("ALTER INDEX %s RENAME TO %s" % (
200 200 self._prepared_index_name(index),
201 201 self._prepared_index_name(NewName(index))))
202 202
203 203 self.execute()
204 204
205 205 def visit_column(self, delta):
206 206 """Rename/change a column."""
207 207 # ALTER COLUMN is implemented as several ALTER statements
208 keys = delta.keys()
208 keys = list(delta.keys())
209 209 if 'type' in keys:
210 210 self._run_subvisit(delta, self._visit_column_type)
211 211 if 'nullable' in keys:
212 212 self._run_subvisit(delta, self._visit_column_nullable)
213 213 if 'server_default' in keys:
214 214 # Skip 'default': only handle server-side defaults, others
215 215 # are managed by the app, not the db.
216 216 self._run_subvisit(delta, self._visit_column_default)
217 217 if 'name' in keys:
218 218 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
219 219
220 220 def _run_subvisit(self, delta, func, start_alter=True):
221 221 """Runs visit method based on what needs to be changed on column"""
222 222 table = self._to_table(delta.table)
223 223 col_name = delta.current_name
224 224 if start_alter:
225 225 self.start_alter_column(table, col_name)
226 226 ret = func(table, delta.result_column, delta)
227 227 self.execute()
228 228
229 229 def start_alter_column(self, table, col_name):
230 230 """Starts ALTER COLUMN"""
231 231 self.start_alter_table(table)
232 232 q = util.safe_quote(table)
233 233 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, q))
234 234
235 235 def _visit_column_nullable(self, table, column, delta):
236 236 nullable = delta['nullable']
237 237 if nullable:
238 238 self.append("DROP NOT NULL")
239 239 else:
240 240 self.append("SET NOT NULL")
241 241
242 242 def _visit_column_default(self, table, column, delta):
243 243 default_text = self.get_column_default_string(column)
244 244 if default_text is not None:
245 245 self.append("SET DEFAULT %s" % default_text)
246 246 else:
247 247 self.append("DROP DEFAULT")
248 248
249 249 def _visit_column_type(self, table, column, delta):
250 250 type_ = delta['type']
251 251 type_text = str(type_.compile(dialect=self.dialect))
252 252 self.append("TYPE %s" % type_text)
253 253
254 254 def _visit_column_name(self, table, column, delta):
255 255 self.start_alter_table(table)
256 256 q = util.safe_quote(table)
257 257 col_name = self.preparer.quote(delta.current_name, q)
258 258 new_name = self.preparer.format_column(delta.result_column)
259 259 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
260 260
261 261
262 262 class ANSIConstraintCommon(AlterTableVisitor):
263 263 """
264 264 Migrate's constraints require a separate creation function from
265 265 SA's: Migrate's constraints are created independently of a table;
266 266 SA's are created at the same time as the table.
267 267 """
268 268
269 269 def get_constraint_name(self, cons):
270 270 """Gets a name for the given constraint.
271 271
272 272 If the name is already set it will be used otherwise the
273 273 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
274 274 method is used.
275 275
276 276 :param cons: constraint object
277 277 """
278 278 if cons.name is not None:
279 279 ret = cons.name
280 280 else:
281 281 ret = cons.name = cons.autoname()
282 282 return self.preparer.quote(ret, cons.quote)
283 283
284 284 def visit_migrate_primary_key_constraint(self, *p, **k):
285 285 self._visit_constraint(*p, **k)
286 286
287 287 def visit_migrate_foreign_key_constraint(self, *p, **k):
288 288 self._visit_constraint(*p, **k)
289 289
290 290 def visit_migrate_check_constraint(self, *p, **k):
291 291 self._visit_constraint(*p, **k)
292 292
293 293 def visit_migrate_unique_constraint(self, *p, **k):
294 294 self._visit_constraint(*p, **k)
295 295
296 296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
297 297 def _visit_constraint(self, constraint):
298 298 constraint.name = self.get_constraint_name(constraint)
299 299 self.append(self.process(AddConstraint(constraint)))
300 300 self.execute()
301 301
302 302 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
303 303 def _visit_constraint(self, constraint):
304 304 constraint.name = self.get_constraint_name(constraint)
305 305 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
306 306 self.execute()
307 307
308 308
309 309 class ANSIDialect(DefaultDialect):
310 310 columngenerator = ANSIColumnGenerator
311 311 columndropper = ANSIColumnDropper
312 312 schemachanger = ANSISchemaChanger
313 313 constraintgenerator = ANSIConstraintGenerator
314 314 constraintdropper = ANSIConstraintDropper
@@ -1,200 +1,200 b''
1 1 """
2 2 This module defines standalone schema constraint classes.
3 3 """
4 4 from sqlalchemy import schema
5 5
6 6 from rhodecode.lib.dbmigrate.migrate.exceptions import *
7 7
8 8
9 9 class ConstraintChangeset(object):
10 10 """Base class for Constraint classes."""
11 11
12 12 def _normalize_columns(self, cols, table_name=False):
13 13 """Given: column objects or names; return col names and
14 14 (maybe) a table"""
15 15 colnames = []
16 16 table = None
17 17 for col in cols:
18 18 if isinstance(col, schema.Column):
19 19 if col.table is not None and table is None:
20 20 table = col.table
21 21 if table_name:
22 22 col = '.'.join((col.table.name, col.name))
23 23 else:
24 24 col = col.name
25 25 colnames.append(col)
26 26 return colnames, table
27 27
28 28 def __do_imports(self, visitor_name, *a, **kw):
29 29 engine = kw.pop('engine', self.table.bind)
30 30 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
31 31 get_engine_visitor, run_single_visitor)
32 32 visitorcallable = get_engine_visitor(engine, visitor_name)
33 33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
34 34
35 35 def create(self, *a, **kw):
36 36 """Create the constraint in the database.
37 37
38 38 :param engine: the database engine to use. If this is \
39 39 :keyword:`None` the instance's engine will be used
40 40 :type engine: :class:`sqlalchemy.engine.base.Engine`
41 41 :param connection: reuse connection istead of creating new one.
42 42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
43 43 """
44 44 # TODO: set the parent here instead of in __init__
45 45 self.__do_imports('constraintgenerator', *a, **kw)
46 46
47 47 def drop(self, *a, **kw):
48 48 """Drop the constraint from the database.
49 49
50 50 :param engine: the database engine to use. If this is
51 51 :keyword:`None` the instance's engine will be used
52 52 :param cascade: Issue CASCADE drop if database supports it
53 53 :type engine: :class:`sqlalchemy.engine.base.Engine`
54 54 :type cascade: bool
55 55 :param connection: reuse connection istead of creating new one.
56 56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
57 57 :returns: Instance with cleared columns
58 58 """
59 59 self.cascade = kw.pop('cascade', False)
60 60 self.__do_imports('constraintdropper', *a, **kw)
61 61 # the spirit of Constraint objects is that they
62 62 # are immutable (just like in a DB. they're only ADDed
63 63 # or DROPped).
64 64 #self.columns.clear()
65 65 return self
66 66
67 67
68 68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
69 69 """Construct PrimaryKeyConstraint
70 70
71 71 Migrate's additional parameters:
72 72
73 73 :param cols: Columns in constraint.
74 74 :param table: If columns are passed as strings, this kw is required
75 75 :type table: Table instance
76 76 :type cols: strings or Column instances
77 77 """
78 78
79 79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
80 80
81 81 def __init__(self, *cols, **kwargs):
82 82 colnames, table = self._normalize_columns(cols)
83 83 table = kwargs.pop('table', table)
84 84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
85 85 if table is not None:
86 86 self._set_parent(table)
87 87
88 88 def autoname(self):
89 89 """Mimic the database's automatic constraint names"""
90 90 return "%s_pkey" % self.table.name
91 91
92 92
93 93 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
94 94 """Construct ForeignKeyConstraint
95 95
96 96 Migrate's additional parameters:
97 97
98 98 :param columns: Columns in constraint
99 99 :param refcolumns: Columns that this FK reffers to in another table.
100 100 :param table: If columns are passed as strings, this kw is required
101 101 :type table: Table instance
102 102 :type columns: list of strings or Column instances
103 103 :type refcolumns: list of strings or Column instances
104 104 """
105 105
106 106 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
107 107
108 108 def __init__(self, columns, refcolumns, *args, **kwargs):
109 109 colnames, table = self._normalize_columns(columns)
110 110 table = kwargs.pop('table', table)
111 111 refcolnames, reftable = self._normalize_columns(refcolumns,
112 112 table_name=True)
113 113 super(ForeignKeyConstraint, self).__init__(
114 114 colnames, refcolnames, *args, **kwargs
115 115 )
116 116 if table is not None:
117 117 self._set_parent(table)
118 118
119 119 @property
120 120 def referenced(self):
121 121 return [e.column for e in self.elements]
122 122
123 123 @property
124 124 def reftable(self):
125 125 return self.referenced[0].table
126 126
127 127 def autoname(self):
128 128 """Mimic the database's automatic constraint names"""
129 129 if hasattr(self.columns, 'keys'):
130 130 # SA <= 0.5
131 firstcol = self.columns[self.columns.keys()[0]]
131 firstcol = self.columns[list(self.columns.keys())[0]]
132 132 ret = "%(table)s_%(firstcolumn)s_fkey" % {
133 133 'table': firstcol.table.name,
134 134 'firstcolumn': firstcol.name,}
135 135 else:
136 136 # SA >= 0.6
137 137 ret = "%(table)s_%(firstcolumn)s_fkey" % {
138 138 'table': self.table.name,
139 139 'firstcolumn': self.columns[0],}
140 140 return ret
141 141
142 142
143 143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
144 144 """Construct CheckConstraint
145 145
146 146 Migrate's additional parameters:
147 147
148 148 :param sqltext: Plain SQL text to check condition
149 149 :param columns: If not name is applied, you must supply this kw\
150 150 to autoname constraint
151 151 :param table: If columns are passed as strings, this kw is required
152 152 :type table: Table instance
153 153 :type columns: list of Columns instances
154 154 :type sqltext: string
155 155 """
156 156
157 157 __migrate_visit_name__ = 'migrate_check_constraint'
158 158
159 159 def __init__(self, sqltext, *args, **kwargs):
160 160 cols = kwargs.pop('columns', [])
161 161 if not cols and not kwargs.get('name', False):
162 162 raise InvalidConstraintError('You must either set "name"'
163 163 'parameter or "columns" to autogenarate it.')
164 164 colnames, table = self._normalize_columns(cols)
165 165 table = kwargs.pop('table', table)
166 166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
167 167 if table is not None:
168 168 self._set_parent(table)
169 169 self.colnames = colnames
170 170
171 171 def autoname(self):
172 172 return "%(table)s_%(cols)s_check" % \
173 173 {'table': self.table.name, 'cols': "_".join(self.colnames)}
174 174
175 175
176 176 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
177 177 """Construct UniqueConstraint
178 178
179 179 Migrate's additional parameters:
180 180
181 181 :param cols: Columns in constraint.
182 182 :param table: If columns are passed as strings, this kw is required
183 183 :type table: Table instance
184 184 :type cols: strings or Column instances
185 185
186 186 .. versionadded:: 0.6.0
187 187 """
188 188
189 189 __migrate_visit_name__ = 'migrate_unique_constraint'
190 190
191 191 def __init__(self, *cols, **kwargs):
192 192 self.colnames, table = self._normalize_columns(cols)
193 193 table = kwargs.pop('table', table)
194 194 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
195 195 if table is not None:
196 196 self._set_parent(table)
197 197
198 198 def autoname(self):
199 199 """Mimic the database's automatic constraint names"""
200 200 return "%s_%s_key" % (self.table.name, '_'.join(self.colnames))
@@ -1,108 +1,108 b''
1 1 """
2 2 Oracle database specific implementations of changeset classes.
3 3 """
4 4 import sqlalchemy as sa
5 5 from sqlalchemy.databases import oracle as sa_base
6 6
7 7 from rhodecode.lib.dbmigrate.migrate import exceptions
8 8 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql
9 9
10 10
11 11 OracleSchemaGenerator = sa_base.OracleDDLCompiler
12 12
13 13
14 14 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
15 15 pass
16 16
17 17
18 18 class OracleColumnDropper(ansisql.ANSIColumnDropper):
19 19 pass
20 20
21 21
22 22 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
23 23
24 24 def get_column_specification(self, column, **kwargs):
25 25 # Ignore the NOT NULL generated
26 26 override_nullable = kwargs.pop('override_nullable', None)
27 27 if override_nullable:
28 28 orig = column.nullable
29 29 column.nullable = True
30 30 ret = super(OracleSchemaChanger, self).get_column_specification(
31 31 column, **kwargs)
32 32 if override_nullable:
33 33 column.nullable = orig
34 34 return ret
35 35
36 36 def visit_column(self, delta):
37 keys = delta.keys()
37 keys = list(delta.keys())
38 38
39 39 if 'name' in keys:
40 40 self._run_subvisit(delta,
41 41 self._visit_column_name,
42 42 start_alter=False)
43 43
44 44 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
45 45 self._run_subvisit(delta,
46 46 self._visit_column_change,
47 47 start_alter=False)
48 48
49 49 def _visit_column_change(self, table, column, delta):
50 50 # Oracle cannot drop a default once created, but it can set it
51 51 # to null. We'll do that if default=None
52 52 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
53 53 dropdefault_hack = (column.server_default is None \
54 and 'server_default' in delta.keys())
54 and 'server_default' in list(delta.keys()))
55 55 # Oracle apparently doesn't like it when we say "not null" if
56 56 # the column's already not null. Fudge it, so we don't need a
57 57 # new function
58 58 notnull_hack = ((not column.nullable) \
59 and ('nullable' not in delta.keys()))
59 and ('nullable' not in list(delta.keys())))
60 60 # We need to specify NULL if we're removing a NOT NULL
61 61 # constraint
62 null_hack = (column.nullable and ('nullable' in delta.keys()))
62 null_hack = (column.nullable and ('nullable' in list(delta.keys())))
63 63
64 64 if dropdefault_hack:
65 65 column.server_default = sa.PassiveDefault(sa.sql.null())
66 66 if notnull_hack:
67 67 column.nullable = True
68 68 colspec = self.get_column_specification(column,
69 69 override_nullable=null_hack)
70 70 if null_hack:
71 71 colspec += ' NULL'
72 72 if notnull_hack:
73 73 column.nullable = False
74 74 if dropdefault_hack:
75 75 column.server_default = None
76 76
77 77 self.start_alter_table(table)
78 78 self.append("MODIFY (")
79 79 self.append(colspec)
80 80 self.append(")")
81 81
82 82
83 83 class OracleConstraintCommon(object):
84 84
85 85 def get_constraint_name(self, cons):
86 86 # Oracle constraints can't guess their name like other DBs
87 87 if not cons.name:
88 88 raise exceptions.NotSupportedError(
89 89 "Oracle constraint names must be explicitly stated")
90 90 return cons.name
91 91
92 92
93 93 class OracleConstraintGenerator(OracleConstraintCommon,
94 94 ansisql.ANSIConstraintGenerator):
95 95 pass
96 96
97 97
98 98 class OracleConstraintDropper(OracleConstraintCommon,
99 99 ansisql.ANSIConstraintDropper):
100 100 pass
101 101
102 102
103 103 class OracleDialect(ansisql.ANSIDialect):
104 104 columngenerator = OracleColumnGenerator
105 105 columndropper = OracleColumnDropper
106 106 schemachanger = OracleSchemaChanger
107 107 constraintgenerator = OracleConstraintGenerator
108 108 constraintdropper = OracleConstraintDropper
@@ -1,668 +1,668 b''
1 1 """
2 2 Schema module providing common schema operations.
3 3 """
4 4 import abc
5 5 try: # Python 3
6 6 from collections.abc import MutableMapping as DictMixin
7 7 except ImportError: # Python 2
8 8 from UserDict import DictMixin
9 9 import warnings
10 10
11 11 import sqlalchemy
12 12
13 13 from sqlalchemy.schema import ForeignKeyConstraint
14 14 from sqlalchemy.schema import UniqueConstraint
15 15
16 16 from rhodecode.lib.dbmigrate.migrate.exceptions import *
17 17 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07, SQLA_08
18 18 from rhodecode.lib.dbmigrate.migrate.changeset import util
19 19 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
20 20 get_engine_visitor, run_single_visitor)
21 21
22 22
23 23 __all__ = [
24 24 'create_column',
25 25 'drop_column',
26 26 'alter_column',
27 27 'rename_table',
28 28 'rename_index',
29 29 'ChangesetTable',
30 30 'ChangesetColumn',
31 31 'ChangesetIndex',
32 32 'ChangesetDefaultClause',
33 33 'ColumnDelta',
34 34 ]
35 35
36 36 def create_column(column, table=None, *p, **kw):
37 37 """Create a column, given the table.
38 38
39 39 API to :meth:`ChangesetColumn.create`.
40 40 """
41 41 if table is not None:
42 42 return table.create_column(column, *p, **kw)
43 43 return column.create(*p, **kw)
44 44
45 45
46 46 def drop_column(column, table=None, *p, **kw):
47 47 """Drop a column, given the table.
48 48
49 49 API to :meth:`ChangesetColumn.drop`.
50 50 """
51 51 if table is not None:
52 52 return table.drop_column(column, *p, **kw)
53 53 return column.drop(*p, **kw)
54 54
55 55
56 56 def rename_table(table, name, engine=None, **kw):
57 57 """Rename a table.
58 58
59 59 If Table instance is given, engine is not used.
60 60
61 61 API to :meth:`ChangesetTable.rename`.
62 62
63 63 :param table: Table to be renamed.
64 64 :param name: New name for Table.
65 65 :param engine: Engine instance.
66 66 :type table: string or Table instance
67 67 :type name: string
68 68 :type engine: obj
69 69 """
70 70 table = _to_table(table, engine)
71 71 table.rename(name, **kw)
72 72
73 73
74 74 def rename_index(index, name, table=None, engine=None, **kw):
75 75 """Rename an index.
76 76
77 77 If Index instance is given,
78 78 table and engine are not used.
79 79
80 80 API to :meth:`ChangesetIndex.rename`.
81 81
82 82 :param index: Index to be renamed.
83 83 :param name: New name for index.
84 84 :param table: Table to which Index is reffered.
85 85 :param engine: Engine instance.
86 86 :type index: string or Index instance
87 87 :type name: string
88 88 :type table: string or Table instance
89 89 :type engine: obj
90 90 """
91 91 index = _to_index(index, table, engine)
92 92 index.rename(name, **kw)
93 93
94 94
95 95 def alter_column(*p, **k):
96 96 """Alter a column.
97 97
98 98 This is a helper function that creates a :class:`ColumnDelta` and
99 99 runs it.
100 100
101 101 :argument column:
102 102 The name of the column to be altered or a
103 103 :class:`ChangesetColumn` column representing it.
104 104
105 105 :param table:
106 106 A :class:`~sqlalchemy.schema.Table` or table name to
107 107 for the table where the column will be changed.
108 108
109 109 :param engine:
110 110 The :class:`~sqlalchemy.engine.base.Engine` to use for table
111 111 reflection and schema alterations.
112 112
113 113 :returns: A :class:`ColumnDelta` instance representing the change.
114 114
115 115
116 116 """
117 117
118 118 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
119 119 k['table'] = p[0].table
120 120 if 'engine' not in k:
121 121 k['engine'] = k['table'].bind
122 122
123 123 # deprecation
124 124 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
125 125 warnings.warn(
126 126 "Passing a Column object to alter_column is deprecated."
127 127 " Just pass in keyword parameters instead.",
128 128 MigrateDeprecationWarning
129 129 )
130 130 engine = k['engine']
131 131
132 132 # enough tests seem to break when metadata is always altered
133 133 # that this crutch has to be left in until they can be sorted
134 134 # out
135 135 k['alter_metadata']=True
136 136
137 137 delta = ColumnDelta(*p, **k)
138 138
139 139 visitorcallable = get_engine_visitor(engine, 'schemachanger')
140 140 engine._run_visitor(visitorcallable, delta)
141 141
142 142 return delta
143 143
144 144
145 145 def _to_table(table, engine=None):
146 146 """Return if instance of Table, else construct new with metadata"""
147 147 if isinstance(table, sqlalchemy.Table):
148 148 return table
149 149
150 150 # Given: table name, maybe an engine
151 151 meta = sqlalchemy.MetaData()
152 152 if engine is not None:
153 153 meta.bind = engine
154 154 return sqlalchemy.Table(table, meta)
155 155
156 156
157 157 def _to_index(index, table=None, engine=None):
158 158 """Return if instance of Index, else construct new with metadata"""
159 159 if isinstance(index, sqlalchemy.Index):
160 160 return index
161 161
162 162 # Given: index name; table name required
163 163 table = _to_table(table, engine)
164 164 ret = sqlalchemy.Index(index)
165 165 ret.table = table
166 166 return ret
167 167
168 168
169 169 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
170 170 """Extracts the differences between two columns/column-parameters
171 171
172 172 May receive parameters arranged in several different ways:
173 173
174 174 * **current_column, new_column, \*p, \*\*kw**
175 175 Additional parameters can be specified to override column
176 176 differences.
177 177
178 178 * **current_column, \*p, \*\*kw**
179 179 Additional parameters alter current_column. Table name is extracted
180 180 from current_column object.
181 181 Name is changed to current_column.name from current_name,
182 182 if current_name is specified.
183 183
184 184 * **current_col_name, \*p, \*\*kw**
185 185 Table kw must specified.
186 186
187 187 :param table: Table at which current Column should be bound to.\
188 188 If table name is given, reflection will be used.
189 189 :type table: string or Table instance
190 190
191 191 :param metadata: A :class:`MetaData` instance to store
192 192 reflected table names
193 193
194 194 :param engine: When reflecting tables, either engine or metadata must \
195 195 be specified to acquire engine object.
196 196 :type engine: :class:`Engine` instance
197 197 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
198 198 `result_column` through :func:`dict` alike object.
199 199
200 200 * :class:`ColumnDelta`.result_column is altered column with new attributes
201 201
202 202 * :class:`ColumnDelta`.current_name is current name of column in db
203 203
204 204
205 205 """
206 206
207 207 # Column attributes that can be altered
208 208 diff_keys = ('name', 'type', 'primary_key', 'nullable',
209 209 'server_onupdate', 'server_default', 'autoincrement')
210 210 diffs = dict()
211 211 __visit_name__ = 'column'
212 212
213 213 def __init__(self, *p, **kw):
214 214 # 'alter_metadata' is not a public api. It exists purely
215 215 # as a crutch until the tests that fail when 'alter_metadata'
216 216 # behaviour always happens can be sorted out
217 217 self.alter_metadata = kw.pop("alter_metadata", False)
218 218
219 219 self.meta = kw.pop("metadata", None)
220 220 self.engine = kw.pop("engine", None)
221 221
222 222 # Things are initialized differently depending on how many column
223 223 # parameters are given. Figure out how many and call the appropriate
224 224 # method.
225 225 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
226 226 # At least one column specified
227 227 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
228 228 # Two columns specified
229 229 diffs = self.compare_2_columns(*p, **kw)
230 230 else:
231 231 # Exactly one column specified
232 232 diffs = self.compare_1_column(*p, **kw)
233 233 else:
234 234 # Zero columns specified
235 235 if not len(p) or not isinstance(p[0], str):
236 236 raise ValueError("First argument must be column name")
237 237 diffs = self.compare_parameters(*p, **kw)
238 238
239 239 self.apply_diffs(diffs)
240 240
241 241 def __repr__(self):
242 242 return '<ColumnDelta altermetadata=%r, %s>' % (
243 243 self.alter_metadata,
244 244 super(ColumnDelta, self).__repr__()
245 245 )
246 246
247 247 def __getitem__(self, key):
248 if key not in self.keys():
248 if key not in list(self.keys()):
249 249 raise KeyError("No such diff key, available: %s" % self.diffs )
250 250 return getattr(self.result_column, key)
251 251
252 252 def __setitem__(self, key, value):
253 if key not in self.keys():
253 if key not in list(self.keys()):
254 254 raise KeyError("No such diff key, available: %s" % self.diffs )
255 255 setattr(self.result_column, key, value)
256 256
257 257 def __delitem__(self, key):
258 258 raise NotImplementedError
259 259
260 260 def __len__(self):
261 261 raise NotImplementedError
262 262
263 263 def __iter__(self):
264 264 raise NotImplementedError
265 265
266 266 def keys(self):
267 return self.diffs.keys()
267 return list(self.diffs.keys())
268 268
269 269 def compare_parameters(self, current_name, *p, **k):
270 270 """Compares Column objects with reflection"""
271 271 self.table = k.pop('table')
272 272 self.result_column = self._table.c.get(current_name)
273 273 if len(p):
274 274 k = self._extract_parameters(p, k, self.result_column)
275 275 return k
276 276
277 277 def compare_1_column(self, col, *p, **k):
278 278 """Compares one Column object"""
279 279 self.table = k.pop('table', None)
280 280 if self.table is None:
281 281 self.table = col.table
282 282 self.result_column = col
283 283 if len(p):
284 284 k = self._extract_parameters(p, k, self.result_column)
285 285 return k
286 286
287 287 def compare_2_columns(self, old_col, new_col, *p, **k):
288 288 """Compares two Column objects"""
289 289 self.process_column(new_col)
290 290 self.table = k.pop('table', None)
291 291 # we cannot use bool() on table in SA06
292 292 if self.table is None:
293 293 self.table = old_col.table
294 294 if self.table is None:
295 295 new_col.table
296 296 self.result_column = old_col
297 297
298 298 # set differences
299 299 # leave out some stuff for later comp
300 300 for key in (set(self.diff_keys) - set(('type',))):
301 301 val = getattr(new_col, key, None)
302 302 if getattr(self.result_column, key, None) != val:
303 303 k.setdefault(key, val)
304 304
305 305 # inspect types
306 306 if not self.are_column_types_eq(self.result_column.type, new_col.type):
307 307 k.setdefault('type', new_col.type)
308 308
309 309 if len(p):
310 310 k = self._extract_parameters(p, k, self.result_column)
311 311 return k
312 312
313 313 def apply_diffs(self, diffs):
314 314 """Populate dict and column object with new values"""
315 315 self.diffs = diffs
316 316 for key in self.diff_keys:
317 317 if key in diffs:
318 318 setattr(self.result_column, key, diffs[key])
319 319
320 320 self.process_column(self.result_column)
321 321
322 322 # create an instance of class type if not yet
323 323 if 'type' in diffs and callable(self.result_column.type):
324 324 self.result_column.type = self.result_column.type()
325 325
326 326 # add column to the table
327 327 if self.table is not None and self.alter_metadata:
328 328 self.result_column.add_to_table(self.table)
329 329
330 330 def are_column_types_eq(self, old_type, new_type):
331 331 """Compares two types to be equal"""
332 332 ret = old_type.__class__ == new_type.__class__
333 333
334 334 # String length is a special case
335 335 if ret and isinstance(new_type, sqlalchemy.types.String):
336 336 ret = (getattr(old_type, 'length', None) == \
337 337 getattr(new_type, 'length', None))
338 338 return ret
339 339
340 340 def _extract_parameters(self, p, k, column):
341 341 """Extracts data from p and modifies diffs"""
342 342 p = list(p)
343 343 while len(p):
344 344 if isinstance(p[0], str):
345 345 k.setdefault('name', p.pop(0))
346 346 elif isinstance(p[0], sqlalchemy.types.TypeEngine):
347 347 k.setdefault('type', p.pop(0))
348 348 elif callable(p[0]):
349 349 p[0] = p[0]()
350 350 else:
351 351 break
352 352
353 353 if len(p):
354 354 new_col = column.copy_fixed()
355 355 new_col._init_items(*p)
356 356 k = self.compare_2_columns(column, new_col, **k)
357 357 return k
358 358
359 359 def process_column(self, column):
360 360 """Processes default values for column"""
361 361 # XXX: this is a snippet from SA processing of positional parameters
362 362 toinit = list()
363 363
364 364 if column.server_default is not None:
365 365 if isinstance(column.server_default, sqlalchemy.FetchedValue):
366 366 toinit.append(column.server_default)
367 367 else:
368 368 toinit.append(sqlalchemy.DefaultClause(column.server_default))
369 369 if column.server_onupdate is not None:
370 370 if isinstance(column.server_onupdate, FetchedValue):
371 371 toinit.append(column.server_default)
372 372 else:
373 373 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
374 374 for_update=True))
375 375 if toinit:
376 376 column._init_items(*toinit)
377 377
378 378 def _get_table(self):
379 379 return getattr(self, '_table', None)
380 380
381 381 def _set_table(self, table):
382 382 if isinstance(table, str):
383 383 if self.alter_metadata:
384 384 if not self.meta:
385 385 raise ValueError("metadata must be specified for table"
386 386 " reflection when using alter_metadata")
387 387 meta = self.meta
388 388 if self.engine:
389 389 meta.bind = self.engine
390 390 else:
391 391 if not self.engine and not self.meta:
392 392 raise ValueError("engine or metadata must be specified"
393 393 " to reflect tables")
394 394 if not self.engine:
395 395 self.engine = self.meta.bind
396 396 meta = sqlalchemy.MetaData(bind=self.engine)
397 397 self._table = sqlalchemy.Table(table, meta, autoload=True)
398 398 elif isinstance(table, sqlalchemy.Table):
399 399 self._table = table
400 400 if not self.alter_metadata:
401 401 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
402 402 def _get_result_column(self):
403 403 return getattr(self, '_result_column', None)
404 404
405 405 def _set_result_column(self, column):
406 406 """Set Column to Table based on alter_metadata evaluation."""
407 407 self.process_column(column)
408 408 if not hasattr(self, 'current_name'):
409 409 self.current_name = column.name
410 410 if self.alter_metadata:
411 411 self._result_column = column
412 412 else:
413 413 self._result_column = column.copy_fixed()
414 414
415 415 table = property(_get_table, _set_table)
416 416 result_column = property(_get_result_column, _set_result_column)
417 417
418 418
419 419 class ChangesetTable(object):
420 420 """Changeset extensions to SQLAlchemy tables."""
421 421
422 422 def create_column(self, column, *p, **kw):
423 423 """Creates a column.
424 424
425 425 The column parameter may be a column definition or the name of
426 426 a column in this table.
427 427
428 428 API to :meth:`ChangesetColumn.create`
429 429
430 430 :param column: Column to be created
431 431 :type column: Column instance or string
432 432 """
433 433 if not isinstance(column, sqlalchemy.Column):
434 434 # It's a column name
435 435 column = getattr(self.c, str(column))
436 436 column.create(table=self, *p, **kw)
437 437
438 438 def drop_column(self, column, *p, **kw):
439 439 """Drop a column, given its name or definition.
440 440
441 441 API to :meth:`ChangesetColumn.drop`
442 442
443 443 :param column: Column to be droped
444 444 :type column: Column instance or string
445 445 """
446 446 if not isinstance(column, sqlalchemy.Column):
447 447 # It's a column name
448 448 try:
449 449 column = getattr(self.c, str(column))
450 450 except AttributeError:
451 451 # That column isn't part of the table. We don't need
452 452 # its entire definition to drop the column, just its
453 453 # name, so create a dummy column with the same name.
454 454 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
455 455 column.drop(table=self, *p, **kw)
456 456
457 457 def rename(self, name, connection=None, **kwargs):
458 458 """Rename this table.
459 459
460 460 :param name: New name of the table.
461 461 :type name: string
462 462 :param connection: reuse connection istead of creating new one.
463 463 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
464 464 """
465 465 engine = self.bind
466 466 self.new_name = name
467 467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469 469
470 470 # Fix metadata registration
471 471 self.name = name
472 472 self.deregister()
473 473 self._set_parent(self.metadata)
474 474
475 475 def _meta_key(self):
476 476 """Get the meta key for this table."""
477 477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478 478
479 479 def deregister(self):
480 480 """Remove this table from its metadata"""
481 481 if SQLA_07:
482 482 self.metadata._remove_table(self.name, self.schema)
483 483 else:
484 484 key = self._meta_key()
485 485 meta = self.metadata
486 486 if key in meta.tables:
487 487 del meta.tables[key]
488 488
489 489
490 490 class ChangesetColumn(object):
491 491 """Changeset extensions to SQLAlchemy columns."""
492 492
493 493 def alter(self, *p, **k):
494 494 """Makes a call to :func:`alter_column` for the column this
495 495 method is called on.
496 496 """
497 497 if 'table' not in k:
498 498 k['table'] = self.table
499 499 if 'engine' not in k:
500 500 k['engine'] = k['table'].bind
501 501 return alter_column(self, *p, **k)
502 502
503 503 def create(self, table=None, index_name=None, unique_name=None,
504 504 primary_key_name=None, populate_default=True, connection=None, **kwargs):
505 505 """Create this column in the database.
506 506
507 507 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
508 508 for most databases.
509 509
510 510 :param table: Table instance to create on.
511 511 :param index_name: Creates :class:`ChangesetIndex` on this column.
512 512 :param unique_name: Creates :class:\
513 513 `~migrate.changeset.constraint.UniqueConstraint` on this column.
514 514 :param primary_key_name: Creates :class:\
515 515 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
516 516 :param populate_default: If True, created column will be \
517 517 populated with defaults
518 518 :param connection: reuse connection istead of creating new one.
519 519 :type table: Table instance
520 520 :type index_name: string
521 521 :type unique_name: string
522 522 :type primary_key_name: string
523 523 :type populate_default: bool
524 524 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
525 525
526 526 :returns: self
527 527 """
528 528 self.populate_default = populate_default
529 529 self.index_name = index_name
530 530 self.unique_name = unique_name
531 531 self.primary_key_name = primary_key_name
532 532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 533 self._check_sanity_constraints(cons)
534 534
535 535 self.add_to_table(table)
536 536 engine = self.table.bind
537 537 visitorcallable = get_engine_visitor(engine, 'columngenerator')
538 538 engine._run_visitor(visitorcallable, self, connection, **kwargs)
539 539
540 540 # TODO: reuse existing connection
541 541 if self.populate_default and self.default is not None:
542 542 stmt = table.update().values({self: engine._execute_default(self.default)})
543 543 engine.execute(stmt)
544 544
545 545 return self
546 546
547 547 def drop(self, table=None, connection=None, **kwargs):
548 548 """Drop this column from the database, leaving its table intact.
549 549
550 550 ``ALTER TABLE DROP COLUMN``, for most databases.
551 551
552 552 :param connection: reuse connection istead of creating new one.
553 553 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
554 554 """
555 555 if table is not None:
556 556 self.table = table
557 557 engine = self.table.bind
558 558 visitorcallable = get_engine_visitor(engine, 'columndropper')
559 559 engine._run_visitor(visitorcallable, self, connection, **kwargs)
560 560 self.remove_from_table(self.table, unset_table=False)
561 561 self.table = None
562 562 return self
563 563
564 564 def add_to_table(self, table):
565 565 if table is not None and self.table is None:
566 566 if SQLA_07:
567 567 table.append_column(self)
568 568 else:
569 569 self._set_parent(table)
570 570
571 571 def _col_name_in_constraint(self,cons,name):
572 572 return False
573 573
574 574 def remove_from_table(self, table, unset_table=True):
575 575 # TODO: remove primary keys, constraints, etc
576 576 if unset_table:
577 577 self.table = None
578 578
579 579 to_drop = set()
580 580 for index in table.indexes:
581 581 columns = []
582 582 for col in index.columns:
583 583 if col.name!=self.name:
584 584 columns.append(col)
585 585 if columns:
586 586 index.columns = columns
587 587 if SQLA_08:
588 588 index.expressions = columns
589 589 else:
590 590 to_drop.add(index)
591 591 table.indexes = table.indexes - to_drop
592 592
593 593 to_drop = set()
594 594 for cons in table.constraints:
595 595 # TODO: deal with other types of constraint
596 596 if isinstance(cons,(ForeignKeyConstraint,
597 597 UniqueConstraint)):
598 598 for col_name in cons.columns:
599 599 if not isinstance(col_name, str):
600 600 col_name = col_name.name
601 601 if self.name==col_name:
602 602 to_drop.add(cons)
603 603 table.constraints = table.constraints - to_drop
604 604
605 605 if table.c.contains_column(self):
606 606 if SQLA_07:
607 607 table._columns.remove(self)
608 608 else:
609 609 table.c.remove(self)
610 610
611 611 # TODO: this is fixed in 0.6
612 612 def copy_fixed(self, **kw):
613 613 """Create a copy of this ``Column``, with all attributes."""
614 614 q = util.safe_quote(self)
615 615 return sqlalchemy.Column(self.name, self.type, self.default,
616 616 key=self.key,
617 617 primary_key=self.primary_key,
618 618 nullable=self.nullable,
619 619 quote=q,
620 620 index=self.index,
621 621 unique=self.unique,
622 622 onupdate=self.onupdate,
623 623 autoincrement=self.autoincrement,
624 624 server_default=self.server_default,
625 625 server_onupdate=self.server_onupdate,
626 626 *[c.copy(**kw) for c in self.constraints])
627 627
628 628 def _check_sanity_constraints(self, name):
629 629 """Check if constraints names are correct"""
630 630 obj = getattr(self, name)
631 631 if (getattr(self, name[:-5]) and not obj):
632 632 raise InvalidConstraintError("Column.create() accepts index_name,"
633 633 " primary_key_name and unique_name to generate constraints")
634 634 if not isinstance(obj, str) and obj is not None:
635 635 raise InvalidConstraintError(
636 636 "%s argument for column must be constraint name" % name)
637 637
638 638
639 639 class ChangesetIndex(object):
640 640 """Changeset extensions to SQLAlchemy Indexes."""
641 641
642 642 __visit_name__ = 'index'
643 643
644 644 def rename(self, name, connection=None, **kwargs):
645 645 """Change the name of an index.
646 646
647 647 :param name: New name of the Index.
648 648 :type name: string
649 649 :param connection: reuse connection istead of creating new one.
650 650 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
651 651 """
652 652 engine = self.table.bind
653 653 self.new_name = name
654 654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 656 self.name = name
657 657
658 658
659 659 class ChangesetDefaultClause(object):
660 660 """Implements comparison between :class:`DefaultClause` instances"""
661 661
662 662 def __eq__(self, other):
663 663 if isinstance(other, self.__class__):
664 664 if self.arg == other.arg:
665 665 return True
666 666
667 667 def __ne__(self, other):
668 668 return not self.__eq__(other)
@@ -1,386 +1,386 b''
1 1 """
2 2 This module provides an external API to the versioning system.
3 3
4 4 .. versionchanged:: 0.6.0
5 5 :func:`migrate.versioning.api.test` and schema diff functions
6 6 changed order of positional arguments so all accept `url` and `repository`
7 7 as first arguments.
8 8
9 9 .. versionchanged:: 0.5.4
10 10 ``--preview_sql`` displays source file when using SQL scripts.
11 11 If Python script is used, it runs the action with mocked engine and
12 12 returns captured SQL statements.
13 13
14 14 .. versionchanged:: 0.5.4
15 15 Deprecated ``--echo`` parameter in favour of new
16 16 :func:`migrate.versioning.util.construct_engine` behavior.
17 17 """
18 18
19 19 # Dear migrate developers,
20 20 #
21 21 # please do not comment this module using sphinx syntax because its
22 22 # docstrings are presented as user help and most users cannot
23 23 # interpret sphinx annotated ReStructuredText.
24 24 #
25 25 # Thanks,
26 26 # Jan Dittberner
27 27
28 28 import sys
29 29 import inspect
30 30 import logging
31 31
32 32 from rhodecode.lib.dbmigrate.migrate import exceptions
33 33 from rhodecode.lib.dbmigrate.migrate.versioning import (
34 34 repository, schema, version,
35 35 script as script_ # command name conflict
36 36 )
37 37 from rhodecode.lib.dbmigrate.migrate.versioning.util import (
38 38 catch_known_errors, with_engine)
39 39
40 40
41 41 log = logging.getLogger(__name__)
42 42 command_desc = {
43 43 'help': 'displays help on a given command',
44 44 'create': 'create an empty repository at the specified path',
45 45 'script': 'create an empty change Python script',
46 46 'script_sql': 'create empty change SQL scripts for given database',
47 47 'version': 'display the latest version available in a repository',
48 48 'db_version': 'show the current version of the repository under version control',
49 49 'source': 'display the Python code for a particular version in this repository',
50 50 'version_control': 'mark a database as under this repository\'s version control',
51 51 'upgrade': 'upgrade a database to a later version',
52 52 'downgrade': 'downgrade a database to an earlier version',
53 53 'drop_version_control': 'removes version control from a database',
54 54 'manage': 'creates a Python script that runs Migrate with a set of default values',
55 55 'test': 'performs the upgrade and downgrade command on the given database',
56 56 'compare_model_to_db': 'compare MetaData against the current database state',
57 57 'create_model': 'dump the current database as a Python model to stdout',
58 58 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
59 59 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
60 60 }
61 __all__ = command_desc.keys()
61 __all__ = list(command_desc.keys())
62 62
63 63 Repository = repository.Repository
64 64 ControlledSchema = schema.ControlledSchema
65 65 VerNum = version.VerNum
66 66 PythonScript = script_.PythonScript
67 67 SqlScript = script_.SqlScript
68 68
69 69
70 70 # deprecated
71 71 def help(cmd=None, **opts):
72 72 """%prog help COMMAND
73 73
74 74 Displays help on a given command.
75 75 """
76 76 if cmd is None:
77 77 raise exceptions.UsageError(None)
78 78 try:
79 79 func = globals()[cmd]
80 80 except:
81 81 raise exceptions.UsageError(
82 82 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
83 83 ret = func.__doc__
84 84 if sys.argv[0]:
85 85 ret = ret.replace('%prog', sys.argv[0])
86 86 return ret
87 87
88 88 @catch_known_errors
89 89 def create(repository, name, **opts):
90 90 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
91 91
92 92 Create an empty repository at the specified path.
93 93
94 94 You can specify the version_table to be used; by default, it is
95 95 'migrate_version'. This table is created in all version-controlled
96 96 databases.
97 97 """
98 98 repo_path = Repository.create(repository, name, **opts)
99 99
100 100
101 101 @catch_known_errors
102 102 def script(description, repository, **opts):
103 103 """%prog script DESCRIPTION REPOSITORY_PATH
104 104
105 105 Create an empty change script using the next unused version number
106 106 appended with the given description.
107 107
108 108 For instance, manage.py script "Add initial tables" creates:
109 109 repository/versions/001_Add_initial_tables.py
110 110 """
111 111 repo = Repository(repository)
112 112 repo.create_script(description, **opts)
113 113
114 114
115 115 @catch_known_errors
116 116 def script_sql(database, description, repository, **opts):
117 117 """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
118 118
119 119 Create empty change SQL scripts for given DATABASE, where DATABASE
120 120 is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
121 121 or generic ('default').
122 122
123 123 For instance, manage.py script_sql postgresql description creates:
124 124 repository/versions/001_description_postgresql_upgrade.sql and
125 125 repository/versions/001_description_postgresql_downgrade.sql
126 126 """
127 127 repo = Repository(repository)
128 128 repo.create_script_sql(database, description, **opts)
129 129
130 130
131 131 def version(repository, **opts):
132 132 """%prog version REPOSITORY_PATH
133 133
134 134 Display the latest version available in a repository.
135 135 """
136 136 repo = Repository(repository)
137 137 return repo.latest
138 138
139 139
140 140 @with_engine
141 141 def db_version(url, repository, **opts):
142 142 """%prog db_version URL REPOSITORY_PATH
143 143
144 144 Show the current version of the repository with the given
145 145 connection string, under version control of the specified
146 146 repository.
147 147
148 148 The url should be any valid SQLAlchemy connection string.
149 149 """
150 150 engine = opts.pop('engine')
151 151 schema = ControlledSchema(engine, repository)
152 152 return schema.version
153 153
154 154
155 155 def source(version, dest=None, repository=None, **opts):
156 156 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
157 157
158 158 Display the Python code for a particular version in this
159 159 repository. Save it to the file at DESTINATION or, if omitted,
160 160 send to stdout.
161 161 """
162 162 if repository is None:
163 163 raise exceptions.UsageError("A repository must be specified")
164 164 repo = Repository(repository)
165 165 ret = repo.version(version).script().source()
166 166 if dest is not None:
167 167 with open(dest, 'w') as f:
168 168 f.write(ret)
169 169 ret = None
170 170 return ret
171 171
172 172
173 173 def upgrade(url, repository, version=None, **opts):
174 174 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
175 175
176 176 Upgrade a database to a later version.
177 177
178 178 This runs the upgrade() function defined in your change scripts.
179 179
180 180 By default, the database is updated to the latest available
181 181 version. You may specify a version instead, if you wish.
182 182
183 183 You may preview the Python or SQL code to be executed, rather than
184 184 actually executing it, using the appropriate 'preview' option.
185 185 """
186 186 err = "Cannot upgrade a database of version %s to version %s. "\
187 187 "Try 'downgrade' instead."
188 188 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
189 189
190 190
191 191 def downgrade(url, repository, version, **opts):
192 192 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
193 193
194 194 Downgrade a database to an earlier version.
195 195
196 196 This is the reverse of upgrade; this runs the downgrade() function
197 197 defined in your change scripts.
198 198
199 199 You may preview the Python or SQL code to be executed, rather than
200 200 actually executing it, using the appropriate 'preview' option.
201 201 """
202 202 err = "Cannot downgrade a database of version %s to version %s. "\
203 203 "Try 'upgrade' instead."
204 204 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
205 205
206 206 @with_engine
207 207 def test(url, repository, **opts):
208 208 """%prog test URL REPOSITORY_PATH [VERSION]
209 209
210 210 Performs the upgrade and downgrade option on the given
211 211 database. This is not a real test and may leave the database in a
212 212 bad state. You should therefore better run the test on a copy of
213 213 your database.
214 214 """
215 215 engine = opts.pop('engine')
216 216 repos = Repository(repository)
217 217
218 218 # Upgrade
219 219 log.info("Upgrading...")
220 220 script = repos.version(None).script(engine.name, 'upgrade')
221 221 script.run(engine, 1)
222 222 log.info("done")
223 223
224 224 log.info("Downgrading...")
225 225 script = repos.version(None).script(engine.name, 'downgrade')
226 226 script.run(engine, -1)
227 227 log.info("done")
228 228 log.info("Success")
229 229
230 230
231 231 @with_engine
232 232 def version_control(url, repository, version=None, **opts):
233 233 """%prog version_control URL REPOSITORY_PATH [VERSION]
234 234
235 235 Mark a database as under this repository's version control.
236 236
237 237 Once a database is under version control, schema changes should
238 238 only be done via change scripts in this repository.
239 239
240 240 This creates the table version_table in the database.
241 241
242 242 The url should be any valid SQLAlchemy connection string.
243 243
244 244 By default, the database begins at version 0 and is assumed to be
245 245 empty. If the database is not empty, you may specify a version at
246 246 which to begin instead. No attempt is made to verify this
247 247 version's correctness - the database schema is expected to be
248 248 identical to what it would be if the database were created from
249 249 scratch.
250 250 """
251 251 engine = opts.pop('engine')
252 252 ControlledSchema.create(engine, repository, version)
253 253
254 254
255 255 @with_engine
256 256 def drop_version_control(url, repository, **opts):
257 257 """%prog drop_version_control URL REPOSITORY_PATH
258 258
259 259 Removes version control from a database.
260 260 """
261 261 engine = opts.pop('engine')
262 262 schema = ControlledSchema(engine, repository)
263 263 schema.drop()
264 264
265 265
266 266 def manage(file, **opts):
267 267 """%prog manage FILENAME [VARIABLES...]
268 268
269 269 Creates a script that runs Migrate with a set of default values.
270 270
271 271 For example::
272 272
273 273 %prog manage manage.py --repository=/path/to/repository \
274 274 --url=sqlite:///project.db
275 275
276 276 would create the script manage.py. The following two commands
277 277 would then have exactly the same results::
278 278
279 279 python manage.py version
280 280 %prog version --repository=/path/to/repository
281 281 """
282 282 Repository.create_manage_file(file, **opts)
283 283
284 284
285 285 @with_engine
286 286 def compare_model_to_db(url, repository, model, **opts):
287 287 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
288 288
289 289 Compare the current model (assumed to be a module level variable
290 290 of type sqlalchemy.MetaData) against the current database.
291 291
292 292 NOTE: This is EXPERIMENTAL.
293 293 """ # TODO: get rid of EXPERIMENTAL label
294 294 engine = opts.pop('engine')
295 295 return ControlledSchema.compare_model_to_db(engine, model, repository)
296 296
297 297
298 298 @with_engine
299 299 def create_model(url, repository, **opts):
300 300 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
301 301
302 302 Dump the current database as a Python model to stdout.
303 303
304 304 NOTE: This is EXPERIMENTAL.
305 305 """ # TODO: get rid of EXPERIMENTAL label
306 306 engine = opts.pop('engine')
307 307 declarative = opts.get('declarative', False)
308 308 return ControlledSchema.create_model(engine, repository, declarative)
309 309
310 310
311 311 @catch_known_errors
312 312 @with_engine
313 313 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
314 314 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
315 315
316 316 Create a script changing the old Python model to the new (current)
317 317 Python model, sending to stdout.
318 318
319 319 NOTE: This is EXPERIMENTAL.
320 320 """ # TODO: get rid of EXPERIMENTAL label
321 321 engine = opts.pop('engine')
322 322 return PythonScript.make_update_script_for_model(
323 323 engine, oldmodel, model, repository, **opts)
324 324
325 325
326 326 @with_engine
327 327 def update_db_from_model(url, repository, model, **opts):
328 328 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
329 329
330 330 Modify the database to match the structure of the current Python
331 331 model. This also sets the db_version number to the latest in the
332 332 repository.
333 333
334 334 NOTE: This is EXPERIMENTAL.
335 335 """ # TODO: get rid of EXPERIMENTAL label
336 336 engine = opts.pop('engine')
337 337 schema = ControlledSchema(engine, repository)
338 338 schema.update_db_from_model(model)
339 339
340 340 @with_engine
341 341 def _migrate(url, repository, version, upgrade, err, **opts):
342 342 engine = opts.pop('engine')
343 343 url = str(engine.url)
344 344 schema = ControlledSchema(engine, repository)
345 345 version = _migrate_version(schema, version, upgrade, err)
346 346
347 347 changeset = schema.changeset(version)
348 348 for ver, change in changeset:
349 349 nextver = ver + changeset.step
350 350 log.info('%s -> %s... ', ver, nextver)
351 351
352 352 if opts.get('preview_sql'):
353 353 if isinstance(change, PythonScript):
354 354 log.info(change.preview_sql(url, changeset.step, **opts))
355 355 elif isinstance(change, SqlScript):
356 356 log.info(change.source())
357 357
358 358 elif opts.get('preview_py'):
359 359 if not isinstance(change, PythonScript):
360 360 raise exceptions.UsageError("Python source can be only displayed"
361 361 " for python migration files")
362 362 source_ver = max(ver, nextver)
363 363 module = schema.repository.version(source_ver).script().module
364 364 funcname = upgrade and "upgrade" or "downgrade"
365 365 func = getattr(module, funcname)
366 366 log.info(inspect.getsource(func))
367 367 else:
368 368 schema.runchange(ver, change, changeset.step)
369 369 log.info('done')
370 370
371 371
372 372 def _migrate_version(schema, version, upgrade, err):
373 373 if version is None:
374 374 return version
375 375 # Version is specified: ensure we're upgrading in the right direction
376 376 # (current version < target version for upgrading; reverse for down)
377 377 version = VerNum(version)
378 378 cur = schema.version
379 379 if upgrade is not None:
380 380 if upgrade:
381 381 direction = cur <= version
382 382 else:
383 383 direction = cur >= version
384 384 if not direction:
385 385 raise exceptions.KnownError(err % (cur, version))
386 386 return version
@@ -1,302 +1,302 b''
1 1 """
2 2 Code to generate a Python model from a database or differences
3 3 between a model and database.
4 4
5 5 Some of this is borrowed heavily from the AutoCode project at:
6 6 http://code.google.com/p/sqlautocode/
7 7 """
8 8
9 9 import sys
10 10 import logging
11 11
12 12 import sqlalchemy
13 13
14 14 import rhodecode.lib.dbmigrate.migrate
15 15 import rhodecode.lib.dbmigrate.migrate.changeset
16 16
17 17
18 18 log = logging.getLogger(__name__)
19 19 HEADER = """
20 20 ## File autogenerated by genmodel.py
21 21
22 22 from sqlalchemy import *
23 23 """
24 24
25 25 META_DEFINITION = "meta = MetaData()"
26 26
27 27 DECLARATIVE_DEFINITION = """
28 28 from sqlalchemy.ext import declarative
29 29
30 30 Base = declarative.declarative_base()
31 31 """
32 32
33 33
34 34 class ModelGenerator(object):
35 35 """Various transformations from an A, B diff.
36 36
37 37 In the implementation, A tends to be called the model and B
38 38 the database (although this is not true of all diffs).
39 39 The diff is directionless, but transformations apply the diff
40 40 in a particular direction, described in the method name.
41 41 """
42 42
43 43 def __init__(self, diff, engine, declarative=False):
44 44 self.diff = diff
45 45 self.engine = engine
46 46 self.declarative = declarative
47 47
48 48 def column_repr(self, col):
49 49 kwarg = []
50 50 if col.key != col.name:
51 51 kwarg.append('key')
52 52 if col.primary_key:
53 53 col.primary_key = True # otherwise it dumps it as 1
54 54 kwarg.append('primary_key')
55 55 if not col.nullable:
56 56 kwarg.append('nullable')
57 57 if col.onupdate:
58 58 kwarg.append('onupdate')
59 59 if col.default:
60 60 if col.primary_key:
61 61 # I found that PostgreSQL automatically creates a
62 62 # default value for the sequence, but let's not show
63 63 # that.
64 64 pass
65 65 else:
66 66 kwarg.append('default')
67 67 args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
68 68
69 69 # crs: not sure if this is good idea, but it gets rid of extra
70 70 # u''
71 71 name = col.name.encode('utf8')
72 72
73 73 type_ = col.type
74 74 for cls in col.type.__class__.__mro__:
75 75 if cls.__module__ == 'sqlalchemy.types' and \
76 76 not cls.__name__.isupper():
77 77 if cls is not type_.__class__:
78 78 type_ = cls()
79 79 break
80 80
81 81 type_repr = repr(type_)
82 82 if type_repr.endswith('()'):
83 83 type_repr = type_repr[:-2]
84 84
85 85 constraints = [repr(cn) for cn in col.constraints]
86 86
87 87 data = {
88 88 'name': name,
89 89 'commonStuff': ', '.join([type_repr] + constraints + args),
90 90 }
91 91
92 92 if self.declarative:
93 93 return """%(name)s = Column(%(commonStuff)s)""" % data
94 94 else:
95 95 return """Column(%(name)r, %(commonStuff)s)""" % data
96 96
97 97 def _getTableDefn(self, table, metaName='meta'):
98 98 out = []
99 99 tableName = table.name
100 100 if self.declarative:
101 101 out.append("class %(table)s(Base):" % {'table': tableName})
102 102 out.append(" __tablename__ = '%(table)s'\n" %
103 103 {'table': tableName})
104 104 for col in table.columns:
105 105 out.append(" %s" % self.column_repr(col))
106 106 out.append('\n')
107 107 else:
108 108 out.append("%(table)s = Table('%(table)s', %(meta)s," %
109 109 {'table': tableName, 'meta': metaName})
110 110 for col in table.columns:
111 111 out.append(" %s," % self.column_repr(col))
112 112 out.append(")\n")
113 113 return out
114 114
115 115 def _get_tables(self,missingA=False,missingB=False,modified=False):
116 116 to_process = []
117 117 for bool_,names,metadata in (
118 118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
119 119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
120 120 (modified,self.diff.tables_different,self.diff.metadataA),
121 121 ):
122 122 if bool_:
123 123 for name in names:
124 124 yield metadata.tables.get(name)
125 125
126 126 def _genModelHeader(self, tables):
127 127 out = []
128 128 import_index = []
129 129
130 130 out.append(HEADER)
131 131
132 132 for table in tables:
133 133 for col in table.columns:
134 134 if "dialects" in col.type.__module__ and \
135 135 col.type.__class__ not in import_index:
136 136 out.append("from " + col.type.__module__ +
137 137 " import " + col.type.__class__.__name__)
138 138 import_index.append(col.type.__class__)
139 139
140 140 out.append("")
141 141
142 142 if self.declarative:
143 143 out.append(DECLARATIVE_DEFINITION)
144 144 else:
145 145 out.append(META_DEFINITION)
146 146 out.append("")
147 147
148 148 return out
149 149
150 150 def genBDefinition(self):
151 151 """Generates the source code for a definition of B.
152 152
153 153 Assumes a diff where A is empty.
154 154
155 155 Was: toPython. Assume database (B) is current and model (A) is empty.
156 156 """
157 157
158 158 out = []
159 159 out.extend(self._genModelHeader(self._get_tables(missingA=True)))
160 160 for table in self._get_tables(missingA=True):
161 161 out.extend(self._getTableDefn(table))
162 162 return '\n'.join(out)
163 163
164 164 def genB2AMigration(self, indent=' '):
165 165 """Generate a migration from B to A.
166 166
167 167 Was: toUpgradeDowngradePython
168 168 Assume model (A) is most current and database (B) is out-of-date.
169 169 """
170 170
171 171 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
172 172 'pre_meta = MetaData()',
173 173 'post_meta = MetaData()',
174 174 ]
175 175 upgradeCommands = ['pre_meta.bind = migrate_engine',
176 176 'post_meta.bind = migrate_engine']
177 177 downgradeCommands = list(upgradeCommands)
178 178
179 179 for tn in self.diff.tables_missing_from_A:
180 180 pre_table = self.diff.metadataB.tables[tn]
181 181 decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
182 182 upgradeCommands.append(
183 183 "pre_meta.tables[%(table)r].drop()" % {'table': tn})
184 184 downgradeCommands.append(
185 185 "pre_meta.tables[%(table)r].create()" % {'table': tn})
186 186
187 187 for tn in self.diff.tables_missing_from_B:
188 188 post_table = self.diff.metadataA.tables[tn]
189 189 decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
190 190 upgradeCommands.append(
191 191 "post_meta.tables[%(table)r].create()" % {'table': tn})
192 192 downgradeCommands.append(
193 193 "post_meta.tables[%(table)r].drop()" % {'table': tn})
194 194
195 for (tn, td) in self.diff.tables_different.items():
195 for (tn, td) in list(self.diff.tables_different.items()):
196 196 if td.columns_missing_from_A or td.columns_different:
197 197 pre_table = self.diff.metadataB.tables[tn]
198 198 decls.extend(self._getTableDefn(
199 199 pre_table, metaName='pre_meta'))
200 200 if td.columns_missing_from_B or td.columns_different:
201 201 post_table = self.diff.metadataA.tables[tn]
202 202 decls.extend(self._getTableDefn(
203 203 post_table, metaName='post_meta'))
204 204
205 205 for col in td.columns_missing_from_A:
206 206 upgradeCommands.append(
207 207 'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
208 208 downgradeCommands.append(
209 209 'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
210 210 for col in td.columns_missing_from_B:
211 211 upgradeCommands.append(
212 212 'post_meta.tables[%r].columns[%r].create()' % (tn, col))
213 213 downgradeCommands.append(
214 214 'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
215 215 for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
216 216 upgradeCommands.append(
217 217 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
218 218 tn, modelCol.name, databaseCol.name))
219 219 downgradeCommands.append(
220 220 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
221 221 tn, modelCol.name, databaseCol.name))
222 222
223 223 return (
224 224 '\n'.join(decls),
225 225 '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
226 226 '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
227 227
228 228 def _db_can_handle_this_change(self,td):
229 229 """Check if the database can handle going from B to A."""
230 230
231 231 if (td.columns_missing_from_B
232 232 and not td.columns_missing_from_A
233 233 and not td.columns_different):
234 234 # Even sqlite can handle column additions.
235 235 return True
236 236 else:
237 237 return not self.engine.url.drivername.startswith('sqlite')
238 238
239 239 def runB2A(self):
240 240 """Goes from B to A.
241 241
242 242 Was: applyModel. Apply model (A) to current database (B).
243 243 """
244 244
245 245 meta = sqlalchemy.MetaData(self.engine)
246 246
247 247 for table in self._get_tables(missingA=True):
248 248 table = table.tometadata(meta)
249 249 table.drop()
250 250 for table in self._get_tables(missingB=True):
251 251 table = table.tometadata(meta)
252 252 table.create()
253 253 for modelTable in self._get_tables(modified=True):
254 254 tableName = modelTable.name
255 255 modelTable = modelTable.tometadata(meta)
256 256 dbTable = self.diff.metadataB.tables[tableName]
257 257
258 258 td = self.diff.tables_different[tableName]
259 259
260 260 if self._db_can_handle_this_change(td):
261 261
262 262 for col in td.columns_missing_from_B:
263 263 modelTable.columns[col].create()
264 264 for col in td.columns_missing_from_A:
265 265 dbTable.columns[col].drop()
266 266 # XXX handle column changes here.
267 267 else:
268 268 # Sqlite doesn't support drop column, so you have to
269 269 # do more: create temp table, copy data to it, drop
270 270 # old table, create new table, copy data back.
271 271 #
272 272 # I wonder if this is guaranteed to be unique?
273 273 tempName = '_temp_%s' % modelTable.name
274 274
275 275 def getCopyStatement():
276 276 preparer = self.engine.dialect.preparer
277 277 commonCols = []
278 278 for modelCol in modelTable.columns:
279 279 if modelCol.name in dbTable.columns:
280 280 commonCols.append(modelCol.name)
281 281 commonColsStr = ', '.join(commonCols)
282 282 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
283 283 (tableName, commonColsStr, commonColsStr, tempName)
284 284
285 285 # Move the data in one transaction, so that we don't
286 286 # leave the database in a nasty state.
287 287 connection = self.engine.connect()
288 288 trans = connection.begin()
289 289 try:
290 290 connection.execute(
291 291 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
292 292 (tempName, modelTable.name))
293 293 # make sure the drop takes place inside our
294 294 # transaction with the bind parameter
295 295 modelTable.drop(bind=connection)
296 296 modelTable.create(bind=connection)
297 297 connection.execute(getCopyStatement())
298 298 connection.execute('DROP TABLE %s' % tempName)
299 299 trans.commit()
300 300 except:
301 301 trans.rollback()
302 302 raise
@@ -1,100 +1,100 b''
1 1 """
2 2 Script to migrate repository from sqlalchemy <= 0.4.4 to the new
3 3 repository schema. This shouldn't use any other migrate modules, so
4 4 that it can work in any version.
5 5 """
6 6
7 7 import os
8 8 import sys
9 9 import logging
10 10
11 11 log = logging.getLogger(__name__)
12 12
13 13
14 14 def usage():
15 15 """Gives usage information."""
16 print("""Usage: %(prog)s repository-to-migrate
16 print(("""Usage: %(prog)s repository-to-migrate
17 17
18 18 Upgrade your repository to the new flat format.
19 19
20 20 NOTE: You should probably make a backup before running this.
21 """ % {'prog': sys.argv[0]})
21 """ % {'prog': sys.argv[0]}))
22 22
23 23 sys.exit(1)
24 24
25 25
26 26 def delete_file(filepath):
27 27 """Deletes a file and prints a message."""
28 28 log.info('Deleting file: %s', filepath)
29 29 os.remove(filepath)
30 30
31 31
32 32 def move_file(src, tgt):
33 33 """Moves a file and prints a message."""
34 34 log.info('Moving file %s to %s', src, tgt)
35 35 if os.path.exists(tgt):
36 36 raise Exception(
37 37 'Cannot move file %s because target %s already exists' % \
38 38 (src, tgt))
39 39 os.rename(src, tgt)
40 40
41 41
42 42 def delete_directory(dirpath):
43 43 """Delete a directory and print a message."""
44 44 log.info('Deleting directory: %s', dirpath)
45 45 os.rmdir(dirpath)
46 46
47 47
48 48 def migrate_repository(repos):
49 49 """Does the actual migration to the new repository format."""
50 50 log.info('Migrating repository at: %s to new format', repos)
51 51 versions = '%s/versions' % repos
52 52 dirs = os.listdir(versions)
53 53 # Only use int's in list.
54 54 numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
55 55 numdirs.sort() # Sort list.
56 56 for dirname in numdirs:
57 57 origdir = '%s/%s' % (versions, dirname)
58 58 log.info('Working on directory: %s', origdir)
59 59 files = os.listdir(origdir)
60 60 files.sort()
61 61 for filename in files:
62 62 # Delete compiled Python files.
63 63 if filename.endswith('.pyc') or filename.endswith('.pyo'):
64 64 delete_file('%s/%s' % (origdir, filename))
65 65
66 66 # Delete empty __init__.py files.
67 67 origfile = '%s/__init__.py' % origdir
68 68 if os.path.exists(origfile) and len(open(origfile).read()) == 0:
69 69 delete_file(origfile)
70 70
71 71 # Move sql upgrade scripts.
72 72 if filename.endswith('.sql'):
73 73 version, dbms, operation = filename.split('.', 3)[0:3]
74 74 origfile = '%s/%s' % (origdir, filename)
75 75 # For instance: 2.postgres.upgrade.sql ->
76 76 # 002_postgres_upgrade.sql
77 77 tgtfile = '%s/%03d_%s_%s.sql' % (
78 78 versions, int(version), dbms, operation)
79 79 move_file(origfile, tgtfile)
80 80
81 81 # Move Python upgrade script.
82 82 pyfile = '%s.py' % dirname
83 83 pyfilepath = '%s/%s' % (origdir, pyfile)
84 84 if os.path.exists(pyfilepath):
85 85 tgtfile = '%s/%03d.py' % (versions, int(dirname))
86 86 move_file(pyfilepath, tgtfile)
87 87
88 88 # Try to remove directory. Will fail if it's not empty.
89 89 delete_directory(origdir)
90 90
91 91
92 92 def main():
93 93 """Main function to be called when using this script."""
94 94 if len(sys.argv) != 2:
95 95 usage()
96 96 migrate_repository(sys.argv[1])
97 97
98 98
99 99 if __name__ == '__main__':
100 100 main()
@@ -1,243 +1,243 b''
1 1 """
2 2 SQLAlchemy migrate repository management.
3 3 """
4 4 import os
5 5 import shutil
6 6 import string
7 7 import logging
8 8
9 9 from pkg_resources import resource_filename
10 10 from tempita import Template as TempitaTemplate
11 11
12 12 from rhodecode.lib.dbmigrate.migrate import exceptions
13 13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
14 14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
15 15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
16 16
17 17
18 18 log = logging.getLogger(__name__)
19 19
20 20 class Changeset(dict):
21 21 """A collection of changes to be applied to a database.
22 22
23 23 Changesets are bound to a repository and manage a set of
24 24 scripts from that repository.
25 25
26 26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 27 """
28 28
29 29 def __init__(self, start, *changes, **k):
30 30 """
31 31 Give a start version; step must be explicitly stated.
32 32 """
33 33 self.step = k.pop('step', 1)
34 34 self.start = version.VerNum(start)
35 35 self.end = self.start
36 36 for change in changes:
37 37 self.add(change)
38 38
39 39 def __iter__(self):
40 return iter(self.items())
40 return iter(list(self.items()))
41 41
42 42 def keys(self):
43 43 """
44 44 In a series of upgrades x -> y, keys are version x. Sorted.
45 45 """
46 ret = super(Changeset, self).keys()
46 ret = list(super(Changeset, self).keys())
47 47 # Reverse order if downgrading
48 48 ret.sort(reverse=(self.step < 1))
49 49 return ret
50 50
51 51 def values(self):
52 return [self[k] for k in self.keys()]
52 return [self[k] for k in list(self.keys())]
53 53
54 54 def items(self):
55 return zip(self.keys(), self.values())
55 return list(zip(list(self.keys()), list(self.values())))
56 56
57 57 def add(self, change):
58 58 """Add new change to changeset"""
59 59 key = self.end
60 60 self.end += self.step
61 61 self[key] = change
62 62
63 63 def run(self, *p, **k):
64 64 """Run the changeset scripts"""
65 65 for version, script in self:
66 66 script.run(*p, **k)
67 67
68 68
69 69 class Repository(pathed.Pathed):
70 70 """A project's change script repository"""
71 71
72 72 _config = 'migrate.cfg'
73 73 _versions = 'versions'
74 74
75 75 def __init__(self, path):
76 76 log.debug('Loading repository %s...', path)
77 77 self.verify(path)
78 78 super(Repository, self).__init__(path)
79 79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 80 self.versions = version.Collection(os.path.join(self.path,
81 81 self._versions))
82 82 log.debug('Repository %s loaded successfully', path)
83 83 log.debug('Config: %r', self.config.to_dict())
84 84
85 85 @classmethod
86 86 def verify(cls, path):
87 87 """
88 88 Ensure the target path is a valid repository.
89 89
90 90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 91 """
92 92 # Ensure the existence of required files
93 93 try:
94 94 cls.require_found(path)
95 95 cls.require_found(os.path.join(path, cls._config))
96 96 cls.require_found(os.path.join(path, cls._versions))
97 97 except exceptions.PathNotFoundError as e:
98 98 raise exceptions.InvalidRepositoryError(path)
99 99
100 100 @classmethod
101 101 def prepare_config(cls, tmpl_dir, name, options=None):
102 102 """
103 103 Prepare a project configuration file for a new project.
104 104
105 105 :param tmpl_dir: Path to Repository template
106 106 :param config_file: Name of the config file in Repository template
107 107 :param name: Repository name
108 108 :type tmpl_dir: string
109 109 :type config_file: string
110 110 :type name: string
111 111 :returns: Populated config file
112 112 """
113 113 if options is None:
114 114 options = {}
115 115 options.setdefault('version_table', 'migrate_version')
116 116 options.setdefault('repository_id', name)
117 117 options.setdefault('required_dbs', [])
118 118 options.setdefault('use_timestamp_numbering', False)
119 119
120 120 with open(os.path.join(tmpl_dir, cls._config)) as f:
121 121 tmpl = f.read()
122 122 ret = TempitaTemplate(tmpl).substitute(options)
123 123
124 124 # cleanup
125 125 del options['__template_name__']
126 126
127 127 return ret
128 128
129 129 @classmethod
130 130 def create(cls, path, name, **opts):
131 131 """Create a repository at a specified path"""
132 132 cls.require_notfound(path)
133 133 theme = opts.pop('templates_theme', None)
134 134 t_path = opts.pop('templates_path', None)
135 135
136 136 # Create repository
137 137 tmpl_dir = Template(t_path).get_repository(theme=theme)
138 138 shutil.copytree(tmpl_dir, path)
139 139
140 140 # Edit config defaults
141 141 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
142 142 with open(os.path.join(path, cls._config), 'w') as fd:
143 143 fd.write(config_text)
144 144
145 145 opts['repository_name'] = name
146 146
147 147 # Create a management script
148 148 manager = os.path.join(path, 'manage.py')
149 149 Repository.create_manage_file(manager, templates_theme=theme,
150 150 templates_path=t_path, **opts)
151 151
152 152 return cls(path)
153 153
154 154 def create_script(self, description, **k):
155 155 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
156 156
157 157 k['use_timestamp_numbering'] = self.use_timestamp_numbering
158 158 self.versions.create_new_python_version(description, **k)
159 159
160 160 def create_script_sql(self, database, description, **k):
161 161 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
162 162 k['use_timestamp_numbering'] = self.use_timestamp_numbering
163 163 self.versions.create_new_sql_version(database, description, **k)
164 164
165 165 @property
166 166 def latest(self):
167 167 """API to :attr:`migrate.versioning.version.Collection.latest`"""
168 168 return self.versions.latest
169 169
170 170 @property
171 171 def version_table(self):
172 172 """Returns version_table name specified in config"""
173 173 return self.config.get('db_settings', 'version_table')
174 174
175 175 @property
176 176 def id(self):
177 177 """Returns repository id specified in config"""
178 178 return self.config.get('db_settings', 'repository_id')
179 179
180 180 @property
181 181 def use_timestamp_numbering(self):
182 182 """Returns use_timestamp_numbering specified in config"""
183 183 if self.config.has_option('db_settings', 'use_timestamp_numbering'):
184 184 return self.config.getboolean('db_settings', 'use_timestamp_numbering')
185 185 return False
186 186
187 187 def version(self, *p, **k):
188 188 """API to :attr:`migrate.versioning.version.Collection.version`"""
189 189 return self.versions.version(*p, **k)
190 190
191 191 @classmethod
192 192 def clear(cls):
193 193 # TODO: deletes repo
194 194 super(Repository, cls).clear()
195 195 version.Collection.clear()
196 196
197 197 def changeset(self, database, start, end=None):
198 198 """Create a changeset to migrate this database from ver. start to end/latest.
199 199
200 200 :param database: name of database to generate changeset
201 201 :param start: version to start at
202 202 :param end: version to end at (latest if None given)
203 203 :type database: string
204 204 :type start: int
205 205 :type end: int
206 206 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
207 207 """
208 208 start = version.VerNum(start)
209 209
210 210 if end is None:
211 211 end = self.latest
212 212 else:
213 213 end = version.VerNum(end)
214 214
215 215 if start <= end:
216 216 step = 1
217 217 range_mod = 1
218 218 op = 'upgrade'
219 219 else:
220 220 step = -1
221 221 range_mod = 0
222 222 op = 'downgrade'
223 223
224 versions = range(int(start) + range_mod, int(end) + range_mod, step)
224 versions = list(range(int(start) + range_mod, int(end) + range_mod, step))
225 225 changes = [self.version(v).script(database, op) for v in versions]
226 226 ret = Changeset(start, step=step, *changes)
227 227 return ret
228 228
229 229 @classmethod
230 230 def create_manage_file(cls, file_, **opts):
231 231 """Create a project management script (manage.py)
232 232
233 233 :param file_: Destination file to be written
234 234 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
235 235 """
236 236 mng_file = Template(opts.pop('templates_path', None))\
237 237 .get_manage(theme=opts.pop('templates_theme', None))
238 238
239 239 with open(mng_file) as f:
240 240 tmpl = f.read()
241 241
242 242 with open(file_, 'w') as fd:
243 243 fd.write(TempitaTemplate(tmpl).substitute(opts))
@@ -1,221 +1,221 b''
1 1 """
2 2 Database schema version management.
3 3 """
4 4 import sys
5 5 import logging
6 6
7 7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 8 create_engine)
9 9 from sqlalchemy.sql import and_
10 10 from sqlalchemy import exc as sa_exceptions
11 11 from sqlalchemy.sql import bindparam
12 12
13 13 from rhodecode.lib.dbmigrate.migrate import exceptions
14 14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07
15 15 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
16 16 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
17 17 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
18 18 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
19 19
20 20
21 21 log = logging.getLogger(__name__)
22 22
23 23
24 24 class ControlledSchema(object):
25 25 """A database under version control"""
26 26
27 27 def __init__(self, engine, repository):
28 28 if isinstance(repository, str):
29 29 repository = Repository(repository)
30 30 self.engine = engine
31 31 self.repository = repository
32 32 self.meta = MetaData(engine)
33 33 self.load()
34 34
35 35 def __eq__(self, other):
36 36 """Compare two schemas by repositories and versions"""
37 37 return (self.repository is other.repository \
38 38 and self.version == other.version)
39 39
40 40 def load(self):
41 41 """Load controlled schema version info from DB"""
42 42 tname = self.repository.version_table
43 43 try:
44 44 if not hasattr(self, 'table') or self.table is None:
45 45 self.table = Table(tname, self.meta, autoload=True)
46 46
47 47 result = self.engine.execute(self.table.select(
48 48 self.table.c.repository_id == str(self.repository.id)))
49 49
50 50 data = list(result)[0]
51 51 except:
52 52 cls, exc, tb = sys.exc_info()
53 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
53 raise exceptions.DatabaseNotControlledError(exc.__str__()).with_traceback(tb)
54 54
55 55 self.version = data['version']
56 56 return data
57 57
58 58 def drop(self):
59 59 """
60 60 Remove version control from a database.
61 61 """
62 62 if SQLA_07:
63 63 try:
64 64 self.table.drop()
65 65 except sa_exceptions.DatabaseError:
66 66 raise exceptions.DatabaseNotControlledError(str(self.table))
67 67 else:
68 68 try:
69 69 self.table.drop()
70 70 except (sa_exceptions.SQLError):
71 71 raise exceptions.DatabaseNotControlledError(str(self.table))
72 72
73 73 def changeset(self, version=None):
74 74 """API to Changeset creation.
75 75
76 76 Uses self.version for start version and engine.name
77 77 to get database name.
78 78 """
79 79 database = self.engine.name
80 80 start_ver = self.version
81 81 changeset = self.repository.changeset(database, start_ver, version)
82 82 return changeset
83 83
84 84 def runchange(self, ver, change, step):
85 85 startver = ver
86 86 endver = ver + step
87 87 # Current database version must be correct! Don't run if corrupt!
88 88 if self.version != startver:
89 89 raise exceptions.InvalidVersionError("%s is not %s" % \
90 90 (self.version, startver))
91 91 # Run the change
92 92 change.run(self.engine, step)
93 93
94 94 # Update/refresh database version
95 95 self.update_repository_table(startver, endver)
96 96 self.load()
97 97
98 98 def update_repository_table(self, startver, endver):
99 99 """Update version_table with new information"""
100 100 update = self.table.update(and_(self.table.c.version == int(startver),
101 101 self.table.c.repository_id == str(self.repository.id)))
102 102 self.engine.execute(update, version=int(endver))
103 103
104 104 def upgrade(self, version=None):
105 105 """
106 106 Upgrade (or downgrade) to a specified version, or latest version.
107 107 """
108 108 changeset = self.changeset(version)
109 109 for ver, change in changeset:
110 110 self.runchange(ver, change, changeset.step)
111 111
112 112 def update_db_from_model(self, model):
113 113 """
114 114 Modify the database to match the structure of the current Python model.
115 115 """
116 116 model = load_model(model)
117 117
118 118 diff = schemadiff.getDiffOfModelAgainstDatabase(
119 119 model, self.engine, excludeTables=[self.repository.version_table]
120 120 )
121 121 genmodel.ModelGenerator(diff,self.engine).runB2A()
122 122
123 123 self.update_repository_table(self.version, int(self.repository.latest))
124 124
125 125 self.load()
126 126
127 127 @classmethod
128 128 def create(cls, engine, repository, version=None):
129 129 """
130 130 Declare a database to be under a repository's version control.
131 131
132 132 :raises: :exc:`DatabaseAlreadyControlledError`
133 133 :returns: :class:`ControlledSchema`
134 134 """
135 135 # Confirm that the version # is valid: positive, integer,
136 136 # exists in repos
137 137 if isinstance(repository, str):
138 138 repository = Repository(repository)
139 139 version = cls._validate_version(repository, version)
140 140 table = cls._create_table_version(engine, repository, version)
141 141 # TODO: history table
142 142 # Load repository information and return
143 143 return cls(engine, repository)
144 144
145 145 @classmethod
146 146 def _validate_version(cls, repository, version):
147 147 """
148 148 Ensures this is a valid version number for this repository.
149 149
150 150 :raises: :exc:`InvalidVersionError` if invalid
151 151 :return: valid version number
152 152 """
153 153 if version is None:
154 154 version = 0
155 155 try:
156 156 version = VerNum(version) # raises valueerror
157 157 if version < 0 or version > repository.latest:
158 158 raise ValueError()
159 159 except ValueError:
160 160 raise exceptions.InvalidVersionError(version)
161 161 return version
162 162
163 163 @classmethod
164 164 def _create_table_version(cls, engine, repository, version):
165 165 """
166 166 Creates the versioning table in a database.
167 167
168 168 :raises: :exc:`DatabaseAlreadyControlledError`
169 169 """
170 170 # Create tables
171 171 tname = repository.version_table
172 172 meta = MetaData(engine)
173 173
174 174 table = Table(
175 175 tname, meta,
176 176 Column('repository_id', String(250), primary_key=True),
177 177 Column('repository_path', Text),
178 178 Column('version', Integer), )
179 179
180 180 # there can be multiple repositories/schemas in the same db
181 181 if not table.exists():
182 182 table.create()
183 183
184 184 # test for existing repository_id
185 185 s = table.select(table.c.repository_id == bindparam("repository_id"))
186 186 result = engine.execute(s, repository_id=repository.id)
187 187 if result.fetchone():
188 188 raise exceptions.DatabaseAlreadyControlledError
189 189
190 190 # Insert data
191 191 engine.execute(table.insert().values(
192 192 repository_id=repository.id,
193 193 repository_path=repository.path,
194 194 version=int(version)))
195 195 return table
196 196
197 197 @classmethod
198 198 def compare_model_to_db(cls, engine, model, repository):
199 199 """
200 200 Compare the current model against the current database.
201 201 """
202 202 if isinstance(repository, str):
203 203 repository = Repository(repository)
204 204 model = load_model(model)
205 205
206 206 diff = schemadiff.getDiffOfModelAgainstDatabase(
207 207 model, engine, excludeTables=[repository.version_table])
208 208 return diff
209 209
210 210 @classmethod
211 211 def create_model(cls, engine, repository, declarative=False):
212 212 """
213 213 Dump the current database as a Python model.
214 214 """
215 215 if isinstance(repository, str):
216 216 repository = Repository(repository)
217 217
218 218 diff = schemadiff.getDiffOfModelAgainstDatabase(
219 219 MetaData(), engine, excludeTables=[repository.version_table]
220 220 )
221 221 return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()
@@ -1,299 +1,299 b''
1 1 """
2 2 Schema differencing support.
3 3 """
4 4
5 5 import logging
6 6 import sqlalchemy
7 7
8 8 from sqlalchemy.types import Float
9 9
10 10 log = logging.getLogger(__name__)
11 11
12 12
13 13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
14 14 """
15 15 Return differences of model against database.
16 16
17 17 :return: object which will evaluate to :keyword:`True` if there \
18 18 are differences else :keyword:`False`.
19 19 """
20 20 db_metadata = sqlalchemy.MetaData(engine)
21 21 db_metadata.reflect()
22 22
23 23 # sqlite will include a dynamically generated 'sqlite_sequence' table if
24 24 # there are autoincrement sequences in the database; this should not be
25 25 # compared.
26 26 if engine.dialect.name == 'sqlite':
27 27 if 'sqlite_sequence' in db_metadata.tables:
28 28 db_metadata.remove(db_metadata.tables['sqlite_sequence'])
29 29
30 30 return SchemaDiff(metadata, db_metadata,
31 31 labelA='model',
32 32 labelB='database',
33 33 excludeTables=excludeTables)
34 34
35 35
36 36 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
37 37 """
38 38 Return differences of model against another model.
39 39
40 40 :return: object which will evaluate to :keyword:`True` if there \
41 41 are differences else :keyword:`False`.
42 42 """
43 43 return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables)
44 44
45 45
46 46 class ColDiff(object):
47 47 """
48 48 Container for differences in one :class:`~sqlalchemy.schema.Column`
49 49 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
50 50 and ``B``.
51 51
52 52 .. attribute:: col_A
53 53
54 54 The :class:`~sqlalchemy.schema.Column` object for A.
55 55
56 56 .. attribute:: col_B
57 57
58 58 The :class:`~sqlalchemy.schema.Column` object for B.
59 59
60 60 .. attribute:: type_A
61 61
62 62 The most generic type of the :class:`~sqlalchemy.schema.Column`
63 63 object in A.
64 64
65 65 .. attribute:: type_B
66 66
67 67 The most generic type of the :class:`~sqlalchemy.schema.Column`
68 68 object in A.
69 69
70 70 """
71 71
72 72 diff = False
73 73
74 74 def __init__(self,col_A,col_B):
75 75 self.col_A = col_A
76 76 self.col_B = col_B
77 77
78 78 self.type_A = col_A.type
79 79 self.type_B = col_B.type
80 80
81 81 self.affinity_A = self.type_A._type_affinity
82 82 self.affinity_B = self.type_B._type_affinity
83 83
84 84 if self.affinity_A is not self.affinity_B:
85 85 self.diff = True
86 86 return
87 87
88 88 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
89 89 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
90 90 self.diff=True
91 91 return
92 92
93 93 for attr in ('precision','scale','length'):
94 94 A = getattr(self.type_A,attr,None)
95 95 B = getattr(self.type_B,attr,None)
96 96 if not (A is None or B is None) and A!=B:
97 97 self.diff=True
98 98 return
99 99
100 def __nonzero__(self):
100 def __bool__(self):
101 101 return self.diff
102 102
103 103 __bool__ = __nonzero__
104 104
105 105
106 106 class TableDiff(object):
107 107 """
108 108 Container for differences in one :class:`~sqlalchemy.schema.Table`
109 109 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
110 110 and ``B``.
111 111
112 112 .. attribute:: columns_missing_from_A
113 113
114 114 A sequence of column names that were found in B but weren't in
115 115 A.
116 116
117 117 .. attribute:: columns_missing_from_B
118 118
119 119 A sequence of column names that were found in A but weren't in
120 120 B.
121 121
122 122 .. attribute:: columns_different
123 123
124 124 A dictionary containing information about columns that were
125 125 found to be different.
126 126 It maps column names to a :class:`ColDiff` objects describing the
127 127 differences found.
128 128 """
129 129 __slots__ = (
130 130 'columns_missing_from_A',
131 131 'columns_missing_from_B',
132 132 'columns_different',
133 133 )
134 134
135 def __nonzero__(self):
135 def __bool__(self):
136 136 return bool(
137 137 self.columns_missing_from_A or
138 138 self.columns_missing_from_B or
139 139 self.columns_different
140 140 )
141 141
142 142 __bool__ = __nonzero__
143 143
144 144 class SchemaDiff(object):
145 145 """
146 146 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
147 147 objects.
148 148
149 149 The string representation of a :class:`SchemaDiff` will summarise
150 150 the changes found between the two
151 151 :class:`~sqlalchemy.schema.MetaData` objects.
152 152
153 153 The length of a :class:`SchemaDiff` will give the number of
154 154 changes found, enabling it to be used much like a boolean in
155 155 expressions.
156 156
157 157 :param metadataA:
158 158 First :class:`~sqlalchemy.schema.MetaData` to compare.
159 159
160 160 :param metadataB:
161 161 Second :class:`~sqlalchemy.schema.MetaData` to compare.
162 162
163 163 :param labelA:
164 164 The label to use in messages about the first
165 165 :class:`~sqlalchemy.schema.MetaData`.
166 166
167 167 :param labelB:
168 168 The label to use in messages about the second
169 169 :class:`~sqlalchemy.schema.MetaData`.
170 170
171 171 :param excludeTables:
172 172 A sequence of table names to exclude.
173 173
174 174 .. attribute:: tables_missing_from_A
175 175
176 176 A sequence of table names that were found in B but weren't in
177 177 A.
178 178
179 179 .. attribute:: tables_missing_from_B
180 180
181 181 A sequence of table names that were found in A but weren't in
182 182 B.
183 183
184 184 .. attribute:: tables_different
185 185
186 186 A dictionary containing information about tables that were found
187 187 to be different.
188 188 It maps table names to a :class:`TableDiff` objects describing the
189 189 differences found.
190 190 """
191 191
192 192 def __init__(self,
193 193 metadataA, metadataB,
194 194 labelA='metadataA',
195 195 labelB='metadataB',
196 196 excludeTables=None):
197 197
198 198 self.metadataA, self.metadataB = metadataA, metadataB
199 199 self.labelA, self.labelB = labelA, labelB
200 200 self.label_width = max(len(labelA),len(labelB))
201 201 excludeTables = set(excludeTables or [])
202 202
203 203 A_table_names = set(metadataA.tables.keys())
204 204 B_table_names = set(metadataB.tables.keys())
205 205
206 206 self.tables_missing_from_A = sorted(
207 207 B_table_names - A_table_names - excludeTables
208 208 )
209 209 self.tables_missing_from_B = sorted(
210 210 A_table_names - B_table_names - excludeTables
211 211 )
212 212
213 213 self.tables_different = {}
214 214 for table_name in A_table_names.intersection(B_table_names):
215 215
216 216 td = TableDiff()
217 217
218 218 A_table = metadataA.tables[table_name]
219 219 B_table = metadataB.tables[table_name]
220 220
221 221 A_column_names = set(A_table.columns.keys())
222 222 B_column_names = set(B_table.columns.keys())
223 223
224 224 td.columns_missing_from_A = sorted(
225 225 B_column_names - A_column_names
226 226 )
227 227
228 228 td.columns_missing_from_B = sorted(
229 229 A_column_names - B_column_names
230 230 )
231 231
232 232 td.columns_different = {}
233 233
234 234 for col_name in A_column_names.intersection(B_column_names):
235 235
236 236 cd = ColDiff(
237 237 A_table.columns.get(col_name),
238 238 B_table.columns.get(col_name)
239 239 )
240 240
241 241 if cd:
242 242 td.columns_different[col_name]=cd
243 243
244 244 # XXX - index and constraint differences should
245 245 # be checked for here
246 246
247 247 if td:
248 248 self.tables_different[table_name]=td
249 249
250 250 def __str__(self):
251 251 """ Summarize differences. """
252 252 out = []
253 253 column_template =' %%%is: %%r' % self.label_width
254 254
255 255 for names,label in (
256 256 (self.tables_missing_from_A,self.labelA),
257 257 (self.tables_missing_from_B,self.labelB),
258 258 ):
259 259 if names:
260 260 out.append(
261 261 ' tables missing from %s: %s' % (
262 262 label,', '.join(sorted(names))
263 263 )
264 264 )
265 265
266 266 for name,td in sorted(self.tables_different.items()):
267 267 out.append(
268 268 ' table with differences: %s' % name
269 269 )
270 270 for names,label in (
271 271 (td.columns_missing_from_A,self.labelA),
272 272 (td.columns_missing_from_B,self.labelB),
273 273 ):
274 274 if names:
275 275 out.append(
276 276 ' %s missing these columns: %s' % (
277 277 label,', '.join(sorted(names))
278 278 )
279 279 )
280 for name,cd in td.columns_different.items():
280 for name,cd in list(td.columns_different.items()):
281 281 out.append(' column with differences: %s' % name)
282 282 out.append(column_template % (self.labelA,cd.col_A))
283 283 out.append(column_template % (self.labelB,cd.col_B))
284 284
285 285 if out:
286 286 out.insert(0, 'Schema diffs:')
287 287 return '\n'.join(out)
288 288 else:
289 289 return 'No schema diffs'
290 290
291 291 def __len__(self):
292 292 """
293 293 Used in bool evaluation, return of 0 means no diffs.
294 294 """
295 295 return (
296 296 len(self.tables_missing_from_A) +
297 297 len(self.tables_missing_from_B) +
298 298 len(self.tables_different)
299 299 )
@@ -1,215 +1,215 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 """The migrate command-line tool."""
5 5
6 6 import sys
7 7 import inspect
8 8 import logging
9 9 from optparse import OptionParser, BadOptionError
10 10
11 11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 12 from rhodecode.lib.dbmigrate.migrate.versioning import api
13 13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
14 14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
15 15
16 16
17 17 alias = {
18 18 's': api.script,
19 19 'vc': api.version_control,
20 20 'dbv': api.db_version,
21 21 'v': api.version,
22 22 }
23 23
24 24 def alias_setup():
25 25 global alias
26 for key, val in alias.items():
26 for key, val in list(alias.items()):
27 27 setattr(api, key, val)
28 28 alias_setup()
29 29
30 30
31 31 class PassiveOptionParser(OptionParser):
32 32
33 33 def _process_args(self, largs, rargs, values):
34 34 """little hack to support all --some_option=value parameters"""
35 35
36 36 while rargs:
37 37 arg = rargs[0]
38 38 if arg == "--":
39 39 del rargs[0]
40 40 return
41 41 elif arg[0:2] == "--":
42 42 # if parser does not know about the option
43 43 # pass it along (make it anonymous)
44 44 try:
45 45 opt = arg.split('=', 1)[0]
46 46 self._match_long_opt(opt)
47 47 except BadOptionError:
48 48 largs.append(arg)
49 49 del rargs[0]
50 50 else:
51 51 self._process_long_opt(rargs, values)
52 52 elif arg[:1] == "-" and len(arg) > 1:
53 53 self._process_short_opts(rargs, values)
54 54 elif self.allow_interspersed_args:
55 55 largs.append(arg)
56 56 del rargs[0]
57 57
58 58 def main(argv=None, **kwargs):
59 59 """Shell interface to :mod:`migrate.versioning.api`.
60 60
61 61 kwargs are default options that can be overriden with passing
62 62 --some_option as command line option
63 63
64 64 :param disable_logging: Let migrate configure logging
65 65 :type disable_logging: bool
66 66 """
67 67 if argv is not None:
68 68 argv = argv
69 69 else:
70 70 argv = list(sys.argv[1:])
71 71 commands = list(api.__all__)
72 72 commands.sort()
73 73
74 74 usage = """%%prog COMMAND ...
75 75
76 76 Available commands:
77 77 %s
78 78
79 79 Enter "%%prog help COMMAND" for information on a particular command.
80 80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
81 81
82 82 parser = PassiveOptionParser(usage=usage)
83 83 parser.add_option("-d", "--debug",
84 84 action="store_true",
85 85 dest="debug",
86 86 default=False,
87 87 help="Shortcut to turn on DEBUG mode for logging")
88 88 parser.add_option("-q", "--disable_logging",
89 89 action="store_true",
90 90 dest="disable_logging",
91 91 default=False,
92 92 help="Use this option to disable logging configuration")
93 93 help_commands = ['help', '-h', '--help']
94 94 HELP = False
95 95
96 96 try:
97 97 command = argv.pop(0)
98 98 if command in help_commands:
99 99 HELP = True
100 100 command = argv.pop(0)
101 101 except IndexError:
102 102 parser.print_help()
103 103 return
104 104
105 105 command_func = getattr(api, command, None)
106 106 if command_func is None or command.startswith('_'):
107 107 parser.error("Invalid command %s" % command)
108 108
109 109 parser.set_usage(inspect.getdoc(command_func))
110 110 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
111 111 for arg in f_args:
112 112 parser.add_option(
113 113 "--%s" % arg,
114 114 dest=arg,
115 115 action='store',
116 116 type="string")
117 117
118 118 # display help of the current command
119 119 if HELP:
120 120 parser.print_help()
121 121 return
122 122
123 123 options, args = parser.parse_args(argv)
124 124
125 125 # override kwargs with anonymous parameters
126 126 override_kwargs = {}
127 127 for arg in list(args):
128 128 if arg.startswith('--'):
129 129 args.remove(arg)
130 130 if '=' in arg:
131 131 opt, value = arg[2:].split('=', 1)
132 132 else:
133 133 opt = arg[2:]
134 134 value = True
135 135 override_kwargs[opt] = value
136 136
137 137 # override kwargs with options if user is overwriting
138 for key, value in options.__dict__.items():
138 for key, value in list(options.__dict__.items()):
139 139 if value is not None:
140 140 override_kwargs[key] = value
141 141
142 142 # arguments that function accepts without passed kwargs
143 143 f_required = list(f_args)
144 144 candidates = dict(kwargs)
145 145 candidates.update(override_kwargs)
146 for key, value in candidates.items():
146 for key, value in list(candidates.items()):
147 147 if key in f_args:
148 148 f_required.remove(key)
149 149
150 150 # map function arguments to parsed arguments
151 151 for arg in args:
152 152 try:
153 153 kw = f_required.pop(0)
154 154 except IndexError:
155 155 parser.error("Too many arguments for command %s: %s" % (command,
156 156 arg))
157 157 kwargs[kw] = arg
158 158
159 159 # apply overrides
160 160 kwargs.update(override_kwargs)
161 161
162 162 # configure options
163 for key, value in options.__dict__.items():
163 for key, value in list(options.__dict__.items()):
164 164 kwargs.setdefault(key, value)
165 165
166 166 # configure logging
167 167 if not asbool(kwargs.pop('disable_logging', False)):
168 168 # filter to log =< INFO into stdout and rest to stderr
169 169 class SingleLevelFilter(logging.Filter):
170 170 def __init__(self, min=None, max=None):
171 171 self.min = min or 0
172 172 self.max = max or 100
173 173
174 174 def filter(self, record):
175 175 return self.min <= record.levelno <= self.max
176 176
177 177 logger = logging.getLogger()
178 178 h1 = logging.StreamHandler(sys.stdout)
179 179 f1 = SingleLevelFilter(max=logging.INFO)
180 180 h1.addFilter(f1)
181 181 h2 = logging.StreamHandler(sys.stderr)
182 182 f2 = SingleLevelFilter(min=logging.WARN)
183 183 h2.addFilter(f2)
184 184 logger.addHandler(h1)
185 185 logger.addHandler(h2)
186 186
187 187 if options.debug:
188 188 logger.setLevel(logging.DEBUG)
189 189 else:
190 190 logger.setLevel(logging.INFO)
191 191
192 192 log = logging.getLogger(__name__)
193 193
194 194 # check if all args are given
195 195 try:
196 196 num_defaults = len(f_defaults)
197 197 except TypeError:
198 198 num_defaults = 0
199 199 f_args_default = f_args[len(f_args) - num_defaults:]
200 200 required = list(set(f_required) - set(f_args_default))
201 201 required.sort()
202 202 if required:
203 203 parser.error("Not enough arguments for command %s: %s not specified" \
204 204 % (command, ', '.join(required)))
205 205
206 206 # handle command
207 207 try:
208 208 ret = command_func(**kwargs)
209 209 if ret is not None:
210 210 log.info(ret)
211 211 except (exceptions.UsageError, exceptions.KnownError) as e:
212 212 parser.error(e.args[0])
213 213
214 214 if __name__ == "__main__":
215 215 main()
@@ -1,180 +1,180 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3 """.. currentmodule:: migrate.versioning.util"""
4 4
5 5 import warnings
6 6 import logging
7 7 from decorator import decorator
8 8 from pkg_resources import EntryPoint
9 9
10 10 from sqlalchemy import create_engine
11 11 from sqlalchemy.engine import Engine
12 12 from sqlalchemy.pool import StaticPool
13 13
14 14 from rhodecode.lib.dbmigrate.migrate import exceptions
15 15 from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance
16 16 from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path
17 17
18 18
19 19 log = logging.getLogger(__name__)
20 20
21 21
22 22 def load_model(dotted_name):
23 23 """Import module and use module-level variable".
24 24
25 25 :param dotted_name: path to model in form of string: ``some.python.module:Class``
26 26
27 27 .. versionchanged:: 0.5.4
28 28
29 29 """
30 30 if isinstance(dotted_name, str):
31 31 if ':' not in dotted_name:
32 32 # backwards compatibility
33 33 warnings.warn('model should be in form of module.model:User '
34 34 'and not module.model.User', exceptions.MigrateDeprecationWarning)
35 35 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
36 36 return EntryPoint.parse('x=%s' % dotted_name).load(False)
37 37 else:
38 38 # Assume it's already loaded.
39 39 return dotted_name
40 40
41 41 def asbool(obj):
42 42 """Do everything to use object as bool"""
43 43 if isinstance(obj, str):
44 44 obj = obj.strip().lower()
45 45 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
46 46 return True
47 47 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
48 48 return False
49 49 else:
50 50 raise ValueError("String is not true/false: %r" % obj)
51 51 if obj in (True, False):
52 52 return bool(obj)
53 53 else:
54 54 raise ValueError("String is not true/false: %r" % obj)
55 55
56 56 def guess_obj_type(obj):
57 57 """Do everything to guess object type from string
58 58
59 59 Tries to convert to `int`, `bool` and finally returns if not succeded.
60 60
61 61 .. versionadded: 0.5.4
62 62 """
63 63
64 64 result = None
65 65
66 66 try:
67 67 result = int(obj)
68 68 except:
69 69 pass
70 70
71 71 if result is None:
72 72 try:
73 73 result = asbool(obj)
74 74 except:
75 75 pass
76 76
77 77 if result is not None:
78 78 return result
79 79 else:
80 80 return obj
81 81
82 82 @decorator
83 83 def catch_known_errors(f, *a, **kw):
84 84 """Decorator that catches known api errors
85 85
86 86 .. versionadded: 0.5.4
87 87 """
88 88
89 89 try:
90 90 return f(*a, **kw)
91 91 except exceptions.PathFoundError as e:
92 92 raise exceptions.KnownError("The path %s already exists" % e.args[0])
93 93
94 94 def construct_engine(engine, **opts):
95 95 """.. versionadded:: 0.5.4
96 96
97 97 Constructs and returns SQLAlchemy engine.
98 98
99 99 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
100 100
101 101 :param engine: connection string or a existing engine
102 102 :param engine_dict: python dictionary of options to pass to `create_engine`
103 103 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
104 104 :type engine_dict: dict
105 105 :type engine: string or Engine instance
106 106 :type engine_arg_*: string
107 107 :returns: SQLAlchemy Engine
108 108
109 109 .. note::
110 110
111 111 keyword parameters override ``engine_dict`` values.
112 112
113 113 """
114 114 if isinstance(engine, Engine):
115 115 return engine
116 116 elif not isinstance(engine, str):
117 117 raise ValueError("you need to pass either an existing engine or a database uri")
118 118
119 119 # get options for create_engine
120 120 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
121 121 kwargs = opts['engine_dict']
122 122 else:
123 123 kwargs = {}
124 124
125 125 # DEPRECATED: handle echo the old way
126 126 echo = asbool(opts.get('echo', False))
127 127 if echo:
128 128 warnings.warn('echo=True parameter is deprecated, pass '
129 129 'engine_arg_echo=True or engine_dict={"echo": True}',
130 130 exceptions.MigrateDeprecationWarning)
131 131 kwargs['echo'] = echo
132 132
133 133 # parse keyword arguments
134 for key, value in opts.items():
134 for key, value in list(opts.items()):
135 135 if key.startswith('engine_arg_'):
136 136 kwargs[key[11:]] = guess_obj_type(value)
137 137
138 138 log.debug('Constructing engine')
139 139 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
140 140 # seems like 0.5.x branch does not work with engine.dispose and staticpool
141 141 return create_engine(engine, **kwargs)
142 142
143 143 @decorator
144 144 def with_engine(f, *a, **kw):
145 145 """Decorator for :mod:`migrate.versioning.api` functions
146 146 to safely close resources after function usage.
147 147
148 148 Passes engine parameters to :func:`construct_engine` and
149 149 resulting parameter is available as kw['engine'].
150 150
151 151 Engine is disposed after wrapped function is executed.
152 152
153 153 .. versionadded: 0.6.0
154 154 """
155 155 url = a[0]
156 156 engine = construct_engine(url, **kw)
157 157
158 158 try:
159 159 kw['engine'] = engine
160 160 return f(*a, **kw)
161 161 finally:
162 162 if isinstance(engine, Engine) and engine is not url:
163 163 log.debug('Disposing SQLAlchemy engine %s', engine)
164 164 engine.dispose()
165 165
166 166
167 167 class Memoize:
168 168 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
169 169 Will only work on functions with non-mutable arguments
170 170
171 171 ActiveState Code 52201
172 172 """
173 173 def __init__(self, fn):
174 174 self.fn = fn
175 175 self.memo = {}
176 176
177 177 def __call__(self, *args):
178 178 if args not in self.memo:
179 179 self.memo[args] = self.fn(*args)
180 180 return self.memo[args]
@@ -1,15 +1,16 b''
1 1 import os
2 2 import sys
3 import importlib
3 4
4 5 def import_path(fullpath):
5 6 """ Import a file with full path specification. Allows one to
6 7 import from anywhere, something __import__ does not do.
7 8 """
8 9 # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
9 10 path, filename = os.path.split(fullpath)
10 11 filename, ext = os.path.splitext(filename)
11 12 sys.path.append(path)
12 13 module = __import__(filename)
13 reload(module) # Might be out of date during tests
14 importlib.reload(module) # Might be out of date during tests
14 15 del sys.path[-1]
15 16 return module
@@ -1,263 +1,263 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import os
5 5 import re
6 6 import shutil
7 7 import logging
8 8
9 9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
11 11 from datetime import datetime
12 12
13 13
14 14 log = logging.getLogger(__name__)
15 15
16 16 class VerNum(object):
17 17 """A version number that behaves like a string and int at the same time"""
18 18
19 19 _instances = {}
20 20
21 21 def __new__(cls, value):
22 22 val = str(value)
23 23 if val not in cls._instances:
24 24 cls._instances[val] = super(VerNum, cls).__new__(cls)
25 25 ret = cls._instances[val]
26 26 return ret
27 27
28 28 def __init__(self,value):
29 29 self.value = str(int(value))
30 30 if self < 0:
31 31 raise ValueError("Version number cannot be negative")
32 32
33 33 def __add__(self, value):
34 34 ret = int(self) + int(value)
35 35 return VerNum(ret)
36 36
37 37 def __sub__(self, value):
38 38 return self + (int(value) * -1)
39 39
40 40 def __eq__(self, value):
41 41 return int(self) == int(value)
42 42
43 43 def __ne__(self, value):
44 44 return int(self) != int(value)
45 45
46 46 def __lt__(self, value):
47 47 return int(self) < int(value)
48 48
49 49 def __gt__(self, value):
50 50 return int(self) > int(value)
51 51
52 52 def __ge__(self, value):
53 53 return int(self) >= int(value)
54 54
55 55 def __le__(self, value):
56 56 return int(self) <= int(value)
57 57
58 58 def __repr__(self):
59 59 return "<VerNum(%s)>" % self.value
60 60
61 61 def __str__(self):
62 62 return str(self.value)
63 63
64 64 def __int__(self):
65 65 return int(self.value)
66 66
67 67
68 68 class Collection(pathed.Pathed):
69 69 """A collection of versioning scripts in a repository"""
70 70
71 71 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
72 72
73 73 def __init__(self, path):
74 74 """Collect current version scripts in repository
75 75 and store them in self.versions
76 76 """
77 77 super(Collection, self).__init__(path)
78 78
79 79 # Create temporary list of files, allowing skipped version numbers.
80 80 files = os.listdir(path)
81 81 if '1' in files:
82 82 # deprecation
83 83 raise Exception('It looks like you have a repository in the old '
84 84 'format (with directories for each version). '
85 85 'Please convert repository before proceeding.')
86 86
87 87 tempVersions = {}
88 88 for filename in files:
89 89 match = self.FILENAME_WITH_VERSION.match(filename)
90 90 if match:
91 91 num = int(match.group(1))
92 92 tempVersions.setdefault(num, []).append(filename)
93 93 else:
94 94 pass # Must be a helper file or something, let's ignore it.
95 95
96 96 # Create the versions member where the keys
97 97 # are VerNum's and the values are Version's.
98 98 self.versions = {}
99 for num, files in tempVersions.items():
99 for num, files in list(tempVersions.items()):
100 100 self.versions[VerNum(num)] = Version(num, path, files)
101 101
102 102 @property
103 103 def latest(self):
104 104 """:returns: Latest version in Collection"""
105 return max([VerNum(0)] + self.versions.keys())
105 return max([VerNum(0)] + list(self.versions.keys()))
106 106
107 107 def _next_ver_num(self, use_timestamp_numbering):
108 108 if use_timestamp_numbering == True:
109 109 return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
110 110 else:
111 111 return self.latest + 1
112 112
113 113 def create_new_python_version(self, description, **k):
114 114 """Create Python files for new version"""
115 115 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
116 116 extra = str_to_filename(description)
117 117
118 118 if extra:
119 119 if extra == '_':
120 120 extra = ''
121 121 elif not extra.startswith('_'):
122 122 extra = '_%s' % extra
123 123
124 124 filename = '%03d%s.py' % (ver, extra)
125 125 filepath = self._version_path(filename)
126 126
127 127 script.PythonScript.create(filepath, **k)
128 128 self.versions[ver] = Version(ver, self.path, [filename])
129 129
130 130 def create_new_sql_version(self, database, description, **k):
131 131 """Create SQL files for new version"""
132 132 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
133 133 self.versions[ver] = Version(ver, self.path, [])
134 134
135 135 extra = str_to_filename(description)
136 136
137 137 if extra:
138 138 if extra == '_':
139 139 extra = ''
140 140 elif not extra.startswith('_'):
141 141 extra = '_%s' % extra
142 142
143 143 # Create new files.
144 144 for op in ('upgrade', 'downgrade'):
145 145 filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
146 146 filepath = self._version_path(filename)
147 147 script.SqlScript.create(filepath, **k)
148 148 self.versions[ver].add_script(filepath)
149 149
150 150 def version(self, vernum=None):
151 151 """Returns latest Version if vernum is not given.
152 152 Otherwise, returns wanted version"""
153 153 if vernum is None:
154 154 vernum = self.latest
155 155 return self.versions[VerNum(vernum)]
156 156
157 157 @classmethod
158 158 def clear(cls):
159 159 super(Collection, cls).clear()
160 160
161 161 def _version_path(self, ver):
162 162 """Returns path of file in versions repository"""
163 163 return os.path.join(self.path, str(ver))
164 164
165 165
166 166 class Version(object):
167 167 """A single version in a collection
168 168 :param vernum: Version Number
169 169 :param path: Path to script files
170 170 :param filelist: List of scripts
171 171 :type vernum: int, VerNum
172 172 :type path: string
173 173 :type filelist: list
174 174 """
175 175
176 176 def __init__(self, vernum, path, filelist):
177 177 self.version = VerNum(vernum)
178 178
179 179 # Collect scripts in this folder
180 180 self.sql = {}
181 181 self.python = None
182 182
183 183 for script in filelist:
184 184 self.add_script(os.path.join(path, script))
185 185
186 186 def script(self, database=None, operation=None):
187 187 """Returns SQL or Python Script"""
188 188 for db in (database, 'default'):
189 189 # Try to return a .sql script first
190 190 try:
191 191 return self.sql[db][operation]
192 192 except KeyError:
193 193 continue # No .sql script exists
194 194
195 195 # TODO: maybe add force Python parameter?
196 196 ret = self.python
197 197
198 198 assert ret is not None, \
199 199 "There is no script for %d version" % self.version
200 200 return ret
201 201
202 202 def add_script(self, path):
203 203 """Add script to Collection/Version"""
204 204 if path.endswith(Extensions.py):
205 205 self._add_script_py(path)
206 206 elif path.endswith(Extensions.sql):
207 207 self._add_script_sql(path)
208 208
209 209 SQL_FILENAME = re.compile(r'^.*\.sql')
210 210
211 211 def _add_script_sql(self, path):
212 212 basename = os.path.basename(path)
213 213 match = self.SQL_FILENAME.match(basename)
214 214
215 215 if match:
216 216 basename = basename.replace('.sql', '')
217 217 parts = basename.split('_')
218 218 if len(parts) < 3:
219 219 raise exceptions.ScriptError(
220 220 "Invalid SQL script name %s " % basename + \
221 221 "(needs to be ###_description_database_operation.sql)")
222 222 version = parts[0]
223 223 op = parts[-1]
224 224 # NOTE(mriedem): check for ibm_db_sa as the database in the name
225 225 if 'ibm_db_sa' in basename:
226 226 if len(parts) == 6:
227 227 dbms = '_'.join(parts[-4: -1])
228 228 else:
229 229 raise exceptions.ScriptError(
230 230 "Invalid ibm_db_sa SQL script name '%s'; "
231 231 "(needs to be "
232 232 "###_description_ibm_db_sa_operation.sql)" % basename)
233 233 else:
234 234 dbms = parts[-2]
235 235 else:
236 236 raise exceptions.ScriptError(
237 237 "Invalid SQL script name %s " % basename + \
238 238 "(needs to be ###_description_database_operation.sql)")
239 239
240 240 # File the script into a dictionary
241 241 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
242 242
243 243 def _add_script_py(self, path):
244 244 if self.python is not None:
245 245 raise exceptions.ScriptError('You can only have one Python script '
246 246 'per version, but you have: %s and %s' % (self.python, path))
247 247 self.python = script.PythonScript(path)
248 248
249 249
250 250 class Extensions:
251 251 """A namespace for file extensions"""
252 252 py = 'py'
253 253 sql = 'sql'
254 254
255 255 def str_to_filename(s):
256 256 """Replaces spaces, (double and single) quotes
257 257 and double underscores to underscores
258 258 """
259 259
260 260 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
261 261 while '__' in s:
262 262 s = s.replace('__', '_')
263 263 return s
General Comments 0
You need to be logged in to leave comments. Login now