##// END OF EJS Templates
dbmigrate: 2to3 pass with fixes
super-admin -
r4988:ff25f201 default
parent child Browse files
Show More
@@ -1,314 +1,314 b''
1 """
1 """
2 Extensions to SQLAlchemy for altering existing tables.
2 Extensions to SQLAlchemy for altering existing tables.
3
3
4 At the moment, this isn't so much based off of ANSI as much as
4 At the moment, this isn't so much based off of ANSI as much as
5 things that just happen to work with multiple databases.
5 things that just happen to work with multiple databases.
6 """
6 """
7 from io import StringIO
7 import io
8
8
9 import sqlalchemy as sa
9 import sqlalchemy as sa
10 from sqlalchemy.schema import SchemaVisitor
10 from sqlalchemy.schema import SchemaVisitor
11 from sqlalchemy.engine.default import DefaultDialect
11 from sqlalchemy.engine.default import DefaultDialect
12 from sqlalchemy.sql import ClauseElement
12 from sqlalchemy.sql import ClauseElement
13 from sqlalchemy.schema import (ForeignKeyConstraint,
13 from sqlalchemy.schema import (ForeignKeyConstraint,
14 PrimaryKeyConstraint,
14 PrimaryKeyConstraint,
15 CheckConstraint,
15 CheckConstraint,
16 UniqueConstraint,
16 UniqueConstraint,
17 Index)
17 Index)
18
18
19 import sqlalchemy.sql.compiler
19 import sqlalchemy.sql.compiler
20 from rhodecode.lib.dbmigrate.migrate import exceptions
20 from rhodecode.lib.dbmigrate.migrate import exceptions
21 from rhodecode.lib.dbmigrate.migrate.changeset import constraint
21 from rhodecode.lib.dbmigrate.migrate.changeset import constraint
22 from rhodecode.lib.dbmigrate.migrate.changeset import util
22 from rhodecode.lib.dbmigrate.migrate.changeset import util
23
23
24 from sqlalchemy.schema import AddConstraint, DropConstraint
24 from sqlalchemy.schema import AddConstraint, DropConstraint
25 from sqlalchemy.sql.compiler import DDLCompiler
25 from sqlalchemy.sql.compiler import DDLCompiler
26 SchemaGenerator = SchemaDropper = DDLCompiler
26 SchemaGenerator = SchemaDropper = DDLCompiler
27
27
28
28
29 class AlterTableVisitor(SchemaVisitor):
29 class AlterTableVisitor(SchemaVisitor):
30 """Common operations for ``ALTER TABLE`` statements."""
30 """Common operations for ``ALTER TABLE`` statements."""
31
31
32 # engine.Compiler looks for .statement
32 # engine.Compiler looks for .statement
33 # when it spawns off a new compiler
33 # when it spawns off a new compiler
34 statement = ClauseElement()
34 statement = ClauseElement()
35
35
36 def append(self, s):
36 def append(self, s):
37 """Append content to the SchemaIterator's query buffer."""
37 """Append content to the SchemaIterator's query buffer."""
38
38
39 self.buffer.write(s)
39 self.buffer.write(s)
40
40
41 def execute(self):
41 def execute(self):
42 """Execute the contents of the SchemaIterator's buffer."""
42 """Execute the contents of the SchemaIterator's buffer."""
43 try:
43 try:
44 return self.connection.execute(self.buffer.getvalue())
44 return self.connection.execute(self.buffer.getvalue())
45 finally:
45 finally:
46 self.buffer.seek(0)
46 self.buffer.seek(0)
47 self.buffer.truncate()
47 self.buffer.truncate()
48
48
49 def __init__(self, dialect, connection, **kw):
49 def __init__(self, dialect, connection, **kw):
50 self.connection = connection
50 self.connection = connection
51 self.buffer = StringIO.StringIO()
51 self.buffer = io.StringIO()
52 self.preparer = dialect.identifier_preparer
52 self.preparer = dialect.identifier_preparer
53 self.dialect = dialect
53 self.dialect = dialect
54
54
55 def traverse_single(self, elem):
55 def traverse_single(self, elem):
56 ret = super(AlterTableVisitor, self).traverse_single(elem)
56 ret = super(AlterTableVisitor, self).traverse_single(elem)
57 if ret:
57 if ret:
58 # adapt to 0.6 which uses a string-returning
58 # adapt to 0.6 which uses a string-returning
59 # object
59 # object
60 self.append(" %s" % ret)
60 self.append(" %s" % ret)
61
61
62 def _to_table(self, param):
62 def _to_table(self, param):
63 """Returns the table object for the given param object."""
63 """Returns the table object for the given param object."""
64 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
64 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
65 ret = param.table
65 ret = param.table
66 else:
66 else:
67 ret = param
67 ret = param
68 return ret
68 return ret
69
69
70 def start_alter_table(self, param):
70 def start_alter_table(self, param):
71 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
71 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
72
72
73 Use the param object to determine the table name and use it
73 Use the param object to determine the table name and use it
74 for building the SQL statement.
74 for building the SQL statement.
75
75
76 :param param: object to determine the table from
76 :param param: object to determine the table from
77 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
77 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
78 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
78 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
79 or string (table name)
79 or string (table name)
80 """
80 """
81 table = self._to_table(param)
81 table = self._to_table(param)
82 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
82 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
83 return table
83 return table
84
84
85
85
86 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
86 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
87 """Extends ansisql generator for column creation (alter table add col)"""
87 """Extends ansisql generator for column creation (alter table add col)"""
88
88
89 def visit_column(self, column):
89 def visit_column(self, column):
90 """Create a column (table already exists).
90 """Create a column (table already exists).
91
91
92 :param column: column object
92 :param column: column object
93 :type column: :class:`sqlalchemy.Column` instance
93 :type column: :class:`sqlalchemy.Column` instance
94 """
94 """
95 if column.default is not None:
95 if column.default is not None:
96 self.traverse_single(column.default)
96 self.traverse_single(column.default)
97
97
98 table = self.start_alter_table(column)
98 table = self.start_alter_table(column)
99 self.append("ADD ")
99 self.append("ADD ")
100 self.append(self.get_column_specification(column))
100 self.append(self.get_column_specification(column))
101
101
102 for cons in column.constraints:
102 for cons in column.constraints:
103 self.traverse_single(cons)
103 self.traverse_single(cons)
104 self.execute()
104 self.execute()
105
105
106 # ALTER TABLE STATEMENTS
106 # ALTER TABLE STATEMENTS
107
107
108 # add indexes and unique constraints
108 # add indexes and unique constraints
109 if column.index_name:
109 if column.index_name:
110 Index(column.index_name,column).create()
110 Index(column.index_name,column).create()
111 elif column.unique_name:
111 elif column.unique_name:
112 constraint.UniqueConstraint(column,
112 constraint.UniqueConstraint(column,
113 name=column.unique_name).create()
113 name=column.unique_name).create()
114
114
115 # SA bounds FK constraints to table, add manually
115 # SA bounds FK constraints to table, add manually
116 for fk in column.foreign_keys:
116 for fk in column.foreign_keys:
117 self.add_foreignkey(fk.constraint)
117 self.add_foreignkey(fk.constraint)
118
118
119 # add primary key constraint if needed
119 # add primary key constraint if needed
120 if column.primary_key_name:
120 if column.primary_key_name:
121 cons = constraint.PrimaryKeyConstraint(column,
121 cons = constraint.PrimaryKeyConstraint(column,
122 name=column.primary_key_name)
122 name=column.primary_key_name)
123 cons.create()
123 cons.create()
124
124
125 def add_foreignkey(self, fk):
125 def add_foreignkey(self, fk):
126 self.connection.execute(AddConstraint(fk))
126 self.connection.execute(AddConstraint(fk))
127
127
128 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
128 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
129 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
129 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
130 DROP COLUMN``).
130 DROP COLUMN``).
131 """
131 """
132
132
133 def visit_column(self, column):
133 def visit_column(self, column):
134 """Drop a column from its table.
134 """Drop a column from its table.
135
135
136 :param column: the column object
136 :param column: the column object
137 :type column: :class:`sqlalchemy.Column`
137 :type column: :class:`sqlalchemy.Column`
138 """
138 """
139 table = self.start_alter_table(column)
139 table = self.start_alter_table(column)
140 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
140 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
141 self.execute()
141 self.execute()
142
142
143
143
144 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
144 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
145 """Manages changes to existing schema elements.
145 """Manages changes to existing schema elements.
146
146
147 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
147 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
148 is in SchemaGenerator.
148 is in SchemaGenerator.
149
149
150 All items may be renamed. Columns can also have many of their properties -
150 All items may be renamed. Columns can also have many of their properties -
151 type, for example - changed.
151 type, for example - changed.
152
152
153 Each function is passed a tuple, containing (object, name); where
153 Each function is passed a tuple, containing (object, name); where
154 object is a type of object you'd expect for that function
154 object is a type of object you'd expect for that function
155 (ie. table for visit_table) and name is the object's new
155 (ie. table for visit_table) and name is the object's new
156 name. NONE means the name is unchanged.
156 name. NONE means the name is unchanged.
157 """
157 """
158
158
159 def visit_table(self, table):
159 def visit_table(self, table):
160 """Rename a table. Other ops aren't supported."""
160 """Rename a table. Other ops aren't supported."""
161 self.start_alter_table(table)
161 self.start_alter_table(table)
162 q = util.safe_quote(table)
162 q = util.safe_quote(table)
163 self.append("RENAME TO %s" % self.preparer.quote(table.new_name, q))
163 self.append("RENAME TO %s" % self.preparer.quote(table.new_name, q))
164 self.execute()
164 self.execute()
165
165
166 def visit_index(self, index):
166 def visit_index(self, index):
167 """Rename an index"""
167 """Rename an index"""
168 if hasattr(self, '_validate_identifier'):
168 if hasattr(self, '_validate_identifier'):
169 # SA <= 0.6.3
169 # SA <= 0.6.3
170 self.append("ALTER INDEX %s RENAME TO %s" % (
170 self.append("ALTER INDEX %s RENAME TO %s" % (
171 self.preparer.quote(
171 self.preparer.quote(
172 self._validate_identifier(
172 self._validate_identifier(
173 index.name, True), index.quote),
173 index.name, True), index.quote),
174 self.preparer.quote(
174 self.preparer.quote(
175 self._validate_identifier(
175 self._validate_identifier(
176 index.new_name, True), index.quote)))
176 index.new_name, True), index.quote)))
177 elif hasattr(self, '_index_identifier'):
177 elif hasattr(self, '_index_identifier'):
178 # SA >= 0.6.5, < 0.8
178 # SA >= 0.6.5, < 0.8
179 self.append("ALTER INDEX %s RENAME TO %s" % (
179 self.append("ALTER INDEX %s RENAME TO %s" % (
180 self.preparer.quote(
180 self.preparer.quote(
181 self._index_identifier(
181 self._index_identifier(
182 index.name), index.quote),
182 index.name), index.quote),
183 self.preparer.quote(
183 self.preparer.quote(
184 self._index_identifier(
184 self._index_identifier(
185 index.new_name), index.quote)))
185 index.new_name), index.quote)))
186 else:
186 else:
187 # SA >= 0.8
187 # SA >= 0.8
188 class NewName(object):
188 class NewName(object):
189 """Map obj.name -> obj.new_name"""
189 """Map obj.name -> obj.new_name"""
190 def __init__(self, index):
190 def __init__(self, index):
191 self.name = index.new_name
191 self.name = index.new_name
192 self._obj = index
192 self._obj = index
193
193
194 def __getattr__(self, attr):
194 def __getattr__(self, attr):
195 if attr == 'name':
195 if attr == 'name':
196 return getattr(self, attr)
196 return getattr(self, attr)
197 return getattr(self._obj, attr)
197 return getattr(self._obj, attr)
198
198
199 self.append("ALTER INDEX %s RENAME TO %s" % (
199 self.append("ALTER INDEX %s RENAME TO %s" % (
200 self._prepared_index_name(index),
200 self._prepared_index_name(index),
201 self._prepared_index_name(NewName(index))))
201 self._prepared_index_name(NewName(index))))
202
202
203 self.execute()
203 self.execute()
204
204
205 def visit_column(self, delta):
205 def visit_column(self, delta):
206 """Rename/change a column."""
206 """Rename/change a column."""
207 # ALTER COLUMN is implemented as several ALTER statements
207 # ALTER COLUMN is implemented as several ALTER statements
208 keys = delta.keys()
208 keys = list(delta.keys())
209 if 'type' in keys:
209 if 'type' in keys:
210 self._run_subvisit(delta, self._visit_column_type)
210 self._run_subvisit(delta, self._visit_column_type)
211 if 'nullable' in keys:
211 if 'nullable' in keys:
212 self._run_subvisit(delta, self._visit_column_nullable)
212 self._run_subvisit(delta, self._visit_column_nullable)
213 if 'server_default' in keys:
213 if 'server_default' in keys:
214 # Skip 'default': only handle server-side defaults, others
214 # Skip 'default': only handle server-side defaults, others
215 # are managed by the app, not the db.
215 # are managed by the app, not the db.
216 self._run_subvisit(delta, self._visit_column_default)
216 self._run_subvisit(delta, self._visit_column_default)
217 if 'name' in keys:
217 if 'name' in keys:
218 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
218 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
219
219
220 def _run_subvisit(self, delta, func, start_alter=True):
220 def _run_subvisit(self, delta, func, start_alter=True):
221 """Runs visit method based on what needs to be changed on column"""
221 """Runs visit method based on what needs to be changed on column"""
222 table = self._to_table(delta.table)
222 table = self._to_table(delta.table)
223 col_name = delta.current_name
223 col_name = delta.current_name
224 if start_alter:
224 if start_alter:
225 self.start_alter_column(table, col_name)
225 self.start_alter_column(table, col_name)
226 ret = func(table, delta.result_column, delta)
226 ret = func(table, delta.result_column, delta)
227 self.execute()
227 self.execute()
228
228
229 def start_alter_column(self, table, col_name):
229 def start_alter_column(self, table, col_name):
230 """Starts ALTER COLUMN"""
230 """Starts ALTER COLUMN"""
231 self.start_alter_table(table)
231 self.start_alter_table(table)
232 q = util.safe_quote(table)
232 q = util.safe_quote(table)
233 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, q))
233 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, q))
234
234
235 def _visit_column_nullable(self, table, column, delta):
235 def _visit_column_nullable(self, table, column, delta):
236 nullable = delta['nullable']
236 nullable = delta['nullable']
237 if nullable:
237 if nullable:
238 self.append("DROP NOT NULL")
238 self.append("DROP NOT NULL")
239 else:
239 else:
240 self.append("SET NOT NULL")
240 self.append("SET NOT NULL")
241
241
242 def _visit_column_default(self, table, column, delta):
242 def _visit_column_default(self, table, column, delta):
243 default_text = self.get_column_default_string(column)
243 default_text = self.get_column_default_string(column)
244 if default_text is not None:
244 if default_text is not None:
245 self.append("SET DEFAULT %s" % default_text)
245 self.append("SET DEFAULT %s" % default_text)
246 else:
246 else:
247 self.append("DROP DEFAULT")
247 self.append("DROP DEFAULT")
248
248
249 def _visit_column_type(self, table, column, delta):
249 def _visit_column_type(self, table, column, delta):
250 type_ = delta['type']
250 type_ = delta['type']
251 type_text = str(type_.compile(dialect=self.dialect))
251 type_text = str(type_.compile(dialect=self.dialect))
252 self.append("TYPE %s" % type_text)
252 self.append("TYPE %s" % type_text)
253
253
254 def _visit_column_name(self, table, column, delta):
254 def _visit_column_name(self, table, column, delta):
255 self.start_alter_table(table)
255 self.start_alter_table(table)
256 q = util.safe_quote(table)
256 q = util.safe_quote(table)
257 col_name = self.preparer.quote(delta.current_name, q)
257 col_name = self.preparer.quote(delta.current_name, q)
258 new_name = self.preparer.format_column(delta.result_column)
258 new_name = self.preparer.format_column(delta.result_column)
259 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
259 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
260
260
261
261
262 class ANSIConstraintCommon(AlterTableVisitor):
262 class ANSIConstraintCommon(AlterTableVisitor):
263 """
263 """
264 Migrate's constraints require a separate creation function from
264 Migrate's constraints require a separate creation function from
265 SA's: Migrate's constraints are created independently of a table;
265 SA's: Migrate's constraints are created independently of a table;
266 SA's are created at the same time as the table.
266 SA's are created at the same time as the table.
267 """
267 """
268
268
269 def get_constraint_name(self, cons):
269 def get_constraint_name(self, cons):
270 """Gets a name for the given constraint.
270 """Gets a name for the given constraint.
271
271
272 If the name is already set it will be used otherwise the
272 If the name is already set it will be used otherwise the
273 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
273 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
274 method is used.
274 method is used.
275
275
276 :param cons: constraint object
276 :param cons: constraint object
277 """
277 """
278 if cons.name is not None:
278 if cons.name is not None:
279 ret = cons.name
279 ret = cons.name
280 else:
280 else:
281 ret = cons.name = cons.autoname()
281 ret = cons.name = cons.autoname()
282 return self.preparer.quote(ret, cons.quote)
282 return self.preparer.quote(ret, cons.quote)
283
283
284 def visit_migrate_primary_key_constraint(self, *p, **k):
284 def visit_migrate_primary_key_constraint(self, *p, **k):
285 self._visit_constraint(*p, **k)
285 self._visit_constraint(*p, **k)
286
286
287 def visit_migrate_foreign_key_constraint(self, *p, **k):
287 def visit_migrate_foreign_key_constraint(self, *p, **k):
288 self._visit_constraint(*p, **k)
288 self._visit_constraint(*p, **k)
289
289
290 def visit_migrate_check_constraint(self, *p, **k):
290 def visit_migrate_check_constraint(self, *p, **k):
291 self._visit_constraint(*p, **k)
291 self._visit_constraint(*p, **k)
292
292
293 def visit_migrate_unique_constraint(self, *p, **k):
293 def visit_migrate_unique_constraint(self, *p, **k):
294 self._visit_constraint(*p, **k)
294 self._visit_constraint(*p, **k)
295
295
296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
297 def _visit_constraint(self, constraint):
297 def _visit_constraint(self, constraint):
298 constraint.name = self.get_constraint_name(constraint)
298 constraint.name = self.get_constraint_name(constraint)
299 self.append(self.process(AddConstraint(constraint)))
299 self.append(self.process(AddConstraint(constraint)))
300 self.execute()
300 self.execute()
301
301
302 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
302 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
303 def _visit_constraint(self, constraint):
303 def _visit_constraint(self, constraint):
304 constraint.name = self.get_constraint_name(constraint)
304 constraint.name = self.get_constraint_name(constraint)
305 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
305 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
306 self.execute()
306 self.execute()
307
307
308
308
309 class ANSIDialect(DefaultDialect):
309 class ANSIDialect(DefaultDialect):
310 columngenerator = ANSIColumnGenerator
310 columngenerator = ANSIColumnGenerator
311 columndropper = ANSIColumnDropper
311 columndropper = ANSIColumnDropper
312 schemachanger = ANSISchemaChanger
312 schemachanger = ANSISchemaChanger
313 constraintgenerator = ANSIConstraintGenerator
313 constraintgenerator = ANSIConstraintGenerator
314 constraintdropper = ANSIConstraintDropper
314 constraintdropper = ANSIConstraintDropper
@@ -1,200 +1,200 b''
1 """
1 """
2 This module defines standalone schema constraint classes.
2 This module defines standalone schema constraint classes.
3 """
3 """
4 from sqlalchemy import schema
4 from sqlalchemy import schema
5
5
6 from rhodecode.lib.dbmigrate.migrate.exceptions import *
6 from rhodecode.lib.dbmigrate.migrate.exceptions import *
7
7
8
8
9 class ConstraintChangeset(object):
9 class ConstraintChangeset(object):
10 """Base class for Constraint classes."""
10 """Base class for Constraint classes."""
11
11
12 def _normalize_columns(self, cols, table_name=False):
12 def _normalize_columns(self, cols, table_name=False):
13 """Given: column objects or names; return col names and
13 """Given: column objects or names; return col names and
14 (maybe) a table"""
14 (maybe) a table"""
15 colnames = []
15 colnames = []
16 table = None
16 table = None
17 for col in cols:
17 for col in cols:
18 if isinstance(col, schema.Column):
18 if isinstance(col, schema.Column):
19 if col.table is not None and table is None:
19 if col.table is not None and table is None:
20 table = col.table
20 table = col.table
21 if table_name:
21 if table_name:
22 col = '.'.join((col.table.name, col.name))
22 col = '.'.join((col.table.name, col.name))
23 else:
23 else:
24 col = col.name
24 col = col.name
25 colnames.append(col)
25 colnames.append(col)
26 return colnames, table
26 return colnames, table
27
27
28 def __do_imports(self, visitor_name, *a, **kw):
28 def __do_imports(self, visitor_name, *a, **kw):
29 engine = kw.pop('engine', self.table.bind)
29 engine = kw.pop('engine', self.table.bind)
30 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
30 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
31 get_engine_visitor, run_single_visitor)
31 get_engine_visitor, run_single_visitor)
32 visitorcallable = get_engine_visitor(engine, visitor_name)
32 visitorcallable = get_engine_visitor(engine, visitor_name)
33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
34
34
35 def create(self, *a, **kw):
35 def create(self, *a, **kw):
36 """Create the constraint in the database.
36 """Create the constraint in the database.
37
37
38 :param engine: the database engine to use. If this is \
38 :param engine: the database engine to use. If this is \
39 :keyword:`None` the instance's engine will be used
39 :keyword:`None` the instance's engine will be used
40 :type engine: :class:`sqlalchemy.engine.base.Engine`
40 :type engine: :class:`sqlalchemy.engine.base.Engine`
41 :param connection: reuse connection istead of creating new one.
41 :param connection: reuse connection istead of creating new one.
42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
43 """
43 """
44 # TODO: set the parent here instead of in __init__
44 # TODO: set the parent here instead of in __init__
45 self.__do_imports('constraintgenerator', *a, **kw)
45 self.__do_imports('constraintgenerator', *a, **kw)
46
46
47 def drop(self, *a, **kw):
47 def drop(self, *a, **kw):
48 """Drop the constraint from the database.
48 """Drop the constraint from the database.
49
49
50 :param engine: the database engine to use. If this is
50 :param engine: the database engine to use. If this is
51 :keyword:`None` the instance's engine will be used
51 :keyword:`None` the instance's engine will be used
52 :param cascade: Issue CASCADE drop if database supports it
52 :param cascade: Issue CASCADE drop if database supports it
53 :type engine: :class:`sqlalchemy.engine.base.Engine`
53 :type engine: :class:`sqlalchemy.engine.base.Engine`
54 :type cascade: bool
54 :type cascade: bool
55 :param connection: reuse connection istead of creating new one.
55 :param connection: reuse connection istead of creating new one.
56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
57 :returns: Instance with cleared columns
57 :returns: Instance with cleared columns
58 """
58 """
59 self.cascade = kw.pop('cascade', False)
59 self.cascade = kw.pop('cascade', False)
60 self.__do_imports('constraintdropper', *a, **kw)
60 self.__do_imports('constraintdropper', *a, **kw)
61 # the spirit of Constraint objects is that they
61 # the spirit of Constraint objects is that they
62 # are immutable (just like in a DB. they're only ADDed
62 # are immutable (just like in a DB. they're only ADDed
63 # or DROPped).
63 # or DROPped).
64 #self.columns.clear()
64 #self.columns.clear()
65 return self
65 return self
66
66
67
67
68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
69 """Construct PrimaryKeyConstraint
69 """Construct PrimaryKeyConstraint
70
70
71 Migrate's additional parameters:
71 Migrate's additional parameters:
72
72
73 :param cols: Columns in constraint.
73 :param cols: Columns in constraint.
74 :param table: If columns are passed as strings, this kw is required
74 :param table: If columns are passed as strings, this kw is required
75 :type table: Table instance
75 :type table: Table instance
76 :type cols: strings or Column instances
76 :type cols: strings or Column instances
77 """
77 """
78
78
79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
80
80
81 def __init__(self, *cols, **kwargs):
81 def __init__(self, *cols, **kwargs):
82 colnames, table = self._normalize_columns(cols)
82 colnames, table = self._normalize_columns(cols)
83 table = kwargs.pop('table', table)
83 table = kwargs.pop('table', table)
84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
85 if table is not None:
85 if table is not None:
86 self._set_parent(table)
86 self._set_parent(table)
87
87
88 def autoname(self):
88 def autoname(self):
89 """Mimic the database's automatic constraint names"""
89 """Mimic the database's automatic constraint names"""
90 return "%s_pkey" % self.table.name
90 return "%s_pkey" % self.table.name
91
91
92
92
93 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
93 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
94 """Construct ForeignKeyConstraint
94 """Construct ForeignKeyConstraint
95
95
96 Migrate's additional parameters:
96 Migrate's additional parameters:
97
97
98 :param columns: Columns in constraint
98 :param columns: Columns in constraint
99 :param refcolumns: Columns that this FK reffers to in another table.
99 :param refcolumns: Columns that this FK reffers to in another table.
100 :param table: If columns are passed as strings, this kw is required
100 :param table: If columns are passed as strings, this kw is required
101 :type table: Table instance
101 :type table: Table instance
102 :type columns: list of strings or Column instances
102 :type columns: list of strings or Column instances
103 :type refcolumns: list of strings or Column instances
103 :type refcolumns: list of strings or Column instances
104 """
104 """
105
105
106 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
106 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
107
107
108 def __init__(self, columns, refcolumns, *args, **kwargs):
108 def __init__(self, columns, refcolumns, *args, **kwargs):
109 colnames, table = self._normalize_columns(columns)
109 colnames, table = self._normalize_columns(columns)
110 table = kwargs.pop('table', table)
110 table = kwargs.pop('table', table)
111 refcolnames, reftable = self._normalize_columns(refcolumns,
111 refcolnames, reftable = self._normalize_columns(refcolumns,
112 table_name=True)
112 table_name=True)
113 super(ForeignKeyConstraint, self).__init__(
113 super(ForeignKeyConstraint, self).__init__(
114 colnames, refcolnames, *args, **kwargs
114 colnames, refcolnames, *args, **kwargs
115 )
115 )
116 if table is not None:
116 if table is not None:
117 self._set_parent(table)
117 self._set_parent(table)
118
118
119 @property
119 @property
120 def referenced(self):
120 def referenced(self):
121 return [e.column for e in self.elements]
121 return [e.column for e in self.elements]
122
122
123 @property
123 @property
124 def reftable(self):
124 def reftable(self):
125 return self.referenced[0].table
125 return self.referenced[0].table
126
126
127 def autoname(self):
127 def autoname(self):
128 """Mimic the database's automatic constraint names"""
128 """Mimic the database's automatic constraint names"""
129 if hasattr(self.columns, 'keys'):
129 if hasattr(self.columns, 'keys'):
130 # SA <= 0.5
130 # SA <= 0.5
131 firstcol = self.columns[self.columns.keys()[0]]
131 firstcol = self.columns[list(self.columns.keys())[0]]
132 ret = "%(table)s_%(firstcolumn)s_fkey" % {
132 ret = "%(table)s_%(firstcolumn)s_fkey" % {
133 'table': firstcol.table.name,
133 'table': firstcol.table.name,
134 'firstcolumn': firstcol.name,}
134 'firstcolumn': firstcol.name,}
135 else:
135 else:
136 # SA >= 0.6
136 # SA >= 0.6
137 ret = "%(table)s_%(firstcolumn)s_fkey" % {
137 ret = "%(table)s_%(firstcolumn)s_fkey" % {
138 'table': self.table.name,
138 'table': self.table.name,
139 'firstcolumn': self.columns[0],}
139 'firstcolumn': self.columns[0],}
140 return ret
140 return ret
141
141
142
142
143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
144 """Construct CheckConstraint
144 """Construct CheckConstraint
145
145
146 Migrate's additional parameters:
146 Migrate's additional parameters:
147
147
148 :param sqltext: Plain SQL text to check condition
148 :param sqltext: Plain SQL text to check condition
149 :param columns: If not name is applied, you must supply this kw\
149 :param columns: If not name is applied, you must supply this kw\
150 to autoname constraint
150 to autoname constraint
151 :param table: If columns are passed as strings, this kw is required
151 :param table: If columns are passed as strings, this kw is required
152 :type table: Table instance
152 :type table: Table instance
153 :type columns: list of Columns instances
153 :type columns: list of Columns instances
154 :type sqltext: string
154 :type sqltext: string
155 """
155 """
156
156
157 __migrate_visit_name__ = 'migrate_check_constraint'
157 __migrate_visit_name__ = 'migrate_check_constraint'
158
158
159 def __init__(self, sqltext, *args, **kwargs):
159 def __init__(self, sqltext, *args, **kwargs):
160 cols = kwargs.pop('columns', [])
160 cols = kwargs.pop('columns', [])
161 if not cols and not kwargs.get('name', False):
161 if not cols and not kwargs.get('name', False):
162 raise InvalidConstraintError('You must either set "name"'
162 raise InvalidConstraintError('You must either set "name"'
163 'parameter or "columns" to autogenarate it.')
163 'parameter or "columns" to autogenarate it.')
164 colnames, table = self._normalize_columns(cols)
164 colnames, table = self._normalize_columns(cols)
165 table = kwargs.pop('table', table)
165 table = kwargs.pop('table', table)
166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
167 if table is not None:
167 if table is not None:
168 self._set_parent(table)
168 self._set_parent(table)
169 self.colnames = colnames
169 self.colnames = colnames
170
170
171 def autoname(self):
171 def autoname(self):
172 return "%(table)s_%(cols)s_check" % \
172 return "%(table)s_%(cols)s_check" % \
173 {'table': self.table.name, 'cols': "_".join(self.colnames)}
173 {'table': self.table.name, 'cols': "_".join(self.colnames)}
174
174
175
175
176 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
176 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
177 """Construct UniqueConstraint
177 """Construct UniqueConstraint
178
178
179 Migrate's additional parameters:
179 Migrate's additional parameters:
180
180
181 :param cols: Columns in constraint.
181 :param cols: Columns in constraint.
182 :param table: If columns are passed as strings, this kw is required
182 :param table: If columns are passed as strings, this kw is required
183 :type table: Table instance
183 :type table: Table instance
184 :type cols: strings or Column instances
184 :type cols: strings or Column instances
185
185
186 .. versionadded:: 0.6.0
186 .. versionadded:: 0.6.0
187 """
187 """
188
188
189 __migrate_visit_name__ = 'migrate_unique_constraint'
189 __migrate_visit_name__ = 'migrate_unique_constraint'
190
190
191 def __init__(self, *cols, **kwargs):
191 def __init__(self, *cols, **kwargs):
192 self.colnames, table = self._normalize_columns(cols)
192 self.colnames, table = self._normalize_columns(cols)
193 table = kwargs.pop('table', table)
193 table = kwargs.pop('table', table)
194 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
194 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
195 if table is not None:
195 if table is not None:
196 self._set_parent(table)
196 self._set_parent(table)
197
197
198 def autoname(self):
198 def autoname(self):
199 """Mimic the database's automatic constraint names"""
199 """Mimic the database's automatic constraint names"""
200 return "%s_%s_key" % (self.table.name, '_'.join(self.colnames))
200 return "%s_%s_key" % (self.table.name, '_'.join(self.colnames))
@@ -1,108 +1,108 b''
1 """
1 """
2 Oracle database specific implementations of changeset classes.
2 Oracle database specific implementations of changeset classes.
3 """
3 """
4 import sqlalchemy as sa
4 import sqlalchemy as sa
5 from sqlalchemy.databases import oracle as sa_base
5 from sqlalchemy.databases import oracle as sa_base
6
6
7 from rhodecode.lib.dbmigrate.migrate import exceptions
7 from rhodecode.lib.dbmigrate.migrate import exceptions
8 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql
8 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql
9
9
10
10
11 OracleSchemaGenerator = sa_base.OracleDDLCompiler
11 OracleSchemaGenerator = sa_base.OracleDDLCompiler
12
12
13
13
14 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
14 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
15 pass
15 pass
16
16
17
17
18 class OracleColumnDropper(ansisql.ANSIColumnDropper):
18 class OracleColumnDropper(ansisql.ANSIColumnDropper):
19 pass
19 pass
20
20
21
21
22 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
22 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
23
23
24 def get_column_specification(self, column, **kwargs):
24 def get_column_specification(self, column, **kwargs):
25 # Ignore the NOT NULL generated
25 # Ignore the NOT NULL generated
26 override_nullable = kwargs.pop('override_nullable', None)
26 override_nullable = kwargs.pop('override_nullable', None)
27 if override_nullable:
27 if override_nullable:
28 orig = column.nullable
28 orig = column.nullable
29 column.nullable = True
29 column.nullable = True
30 ret = super(OracleSchemaChanger, self).get_column_specification(
30 ret = super(OracleSchemaChanger, self).get_column_specification(
31 column, **kwargs)
31 column, **kwargs)
32 if override_nullable:
32 if override_nullable:
33 column.nullable = orig
33 column.nullable = orig
34 return ret
34 return ret
35
35
36 def visit_column(self, delta):
36 def visit_column(self, delta):
37 keys = delta.keys()
37 keys = list(delta.keys())
38
38
39 if 'name' in keys:
39 if 'name' in keys:
40 self._run_subvisit(delta,
40 self._run_subvisit(delta,
41 self._visit_column_name,
41 self._visit_column_name,
42 start_alter=False)
42 start_alter=False)
43
43
44 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
44 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
45 self._run_subvisit(delta,
45 self._run_subvisit(delta,
46 self._visit_column_change,
46 self._visit_column_change,
47 start_alter=False)
47 start_alter=False)
48
48
49 def _visit_column_change(self, table, column, delta):
49 def _visit_column_change(self, table, column, delta):
50 # Oracle cannot drop a default once created, but it can set it
50 # Oracle cannot drop a default once created, but it can set it
51 # to null. We'll do that if default=None
51 # to null. We'll do that if default=None
52 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
52 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
53 dropdefault_hack = (column.server_default is None \
53 dropdefault_hack = (column.server_default is None \
54 and 'server_default' in delta.keys())
54 and 'server_default' in list(delta.keys()))
55 # Oracle apparently doesn't like it when we say "not null" if
55 # Oracle apparently doesn't like it when we say "not null" if
56 # the column's already not null. Fudge it, so we don't need a
56 # the column's already not null. Fudge it, so we don't need a
57 # new function
57 # new function
58 notnull_hack = ((not column.nullable) \
58 notnull_hack = ((not column.nullable) \
59 and ('nullable' not in delta.keys()))
59 and ('nullable' not in list(delta.keys())))
60 # We need to specify NULL if we're removing a NOT NULL
60 # We need to specify NULL if we're removing a NOT NULL
61 # constraint
61 # constraint
62 null_hack = (column.nullable and ('nullable' in delta.keys()))
62 null_hack = (column.nullable and ('nullable' in list(delta.keys())))
63
63
64 if dropdefault_hack:
64 if dropdefault_hack:
65 column.server_default = sa.PassiveDefault(sa.sql.null())
65 column.server_default = sa.PassiveDefault(sa.sql.null())
66 if notnull_hack:
66 if notnull_hack:
67 column.nullable = True
67 column.nullable = True
68 colspec = self.get_column_specification(column,
68 colspec = self.get_column_specification(column,
69 override_nullable=null_hack)
69 override_nullable=null_hack)
70 if null_hack:
70 if null_hack:
71 colspec += ' NULL'
71 colspec += ' NULL'
72 if notnull_hack:
72 if notnull_hack:
73 column.nullable = False
73 column.nullable = False
74 if dropdefault_hack:
74 if dropdefault_hack:
75 column.server_default = None
75 column.server_default = None
76
76
77 self.start_alter_table(table)
77 self.start_alter_table(table)
78 self.append("MODIFY (")
78 self.append("MODIFY (")
79 self.append(colspec)
79 self.append(colspec)
80 self.append(")")
80 self.append(")")
81
81
82
82
83 class OracleConstraintCommon(object):
83 class OracleConstraintCommon(object):
84
84
85 def get_constraint_name(self, cons):
85 def get_constraint_name(self, cons):
86 # Oracle constraints can't guess their name like other DBs
86 # Oracle constraints can't guess their name like other DBs
87 if not cons.name:
87 if not cons.name:
88 raise exceptions.NotSupportedError(
88 raise exceptions.NotSupportedError(
89 "Oracle constraint names must be explicitly stated")
89 "Oracle constraint names must be explicitly stated")
90 return cons.name
90 return cons.name
91
91
92
92
93 class OracleConstraintGenerator(OracleConstraintCommon,
93 class OracleConstraintGenerator(OracleConstraintCommon,
94 ansisql.ANSIConstraintGenerator):
94 ansisql.ANSIConstraintGenerator):
95 pass
95 pass
96
96
97
97
98 class OracleConstraintDropper(OracleConstraintCommon,
98 class OracleConstraintDropper(OracleConstraintCommon,
99 ansisql.ANSIConstraintDropper):
99 ansisql.ANSIConstraintDropper):
100 pass
100 pass
101
101
102
102
103 class OracleDialect(ansisql.ANSIDialect):
103 class OracleDialect(ansisql.ANSIDialect):
104 columngenerator = OracleColumnGenerator
104 columngenerator = OracleColumnGenerator
105 columndropper = OracleColumnDropper
105 columndropper = OracleColumnDropper
106 schemachanger = OracleSchemaChanger
106 schemachanger = OracleSchemaChanger
107 constraintgenerator = OracleConstraintGenerator
107 constraintgenerator = OracleConstraintGenerator
108 constraintdropper = OracleConstraintDropper
108 constraintdropper = OracleConstraintDropper
@@ -1,668 +1,668 b''
1 """
1 """
2 Schema module providing common schema operations.
2 Schema module providing common schema operations.
3 """
3 """
4 import abc
4 import abc
5 try: # Python 3
5 try: # Python 3
6 from collections.abc import MutableMapping as DictMixin
6 from collections.abc import MutableMapping as DictMixin
7 except ImportError: # Python 2
7 except ImportError: # Python 2
8 from UserDict import DictMixin
8 from UserDict import DictMixin
9 import warnings
9 import warnings
10
10
11 import sqlalchemy
11 import sqlalchemy
12
12
13 from sqlalchemy.schema import ForeignKeyConstraint
13 from sqlalchemy.schema import ForeignKeyConstraint
14 from sqlalchemy.schema import UniqueConstraint
14 from sqlalchemy.schema import UniqueConstraint
15
15
16 from rhodecode.lib.dbmigrate.migrate.exceptions import *
16 from rhodecode.lib.dbmigrate.migrate.exceptions import *
17 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07, SQLA_08
17 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07, SQLA_08
18 from rhodecode.lib.dbmigrate.migrate.changeset import util
18 from rhodecode.lib.dbmigrate.migrate.changeset import util
19 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
19 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (
20 get_engine_visitor, run_single_visitor)
20 get_engine_visitor, run_single_visitor)
21
21
22
22
23 __all__ = [
23 __all__ = [
24 'create_column',
24 'create_column',
25 'drop_column',
25 'drop_column',
26 'alter_column',
26 'alter_column',
27 'rename_table',
27 'rename_table',
28 'rename_index',
28 'rename_index',
29 'ChangesetTable',
29 'ChangesetTable',
30 'ChangesetColumn',
30 'ChangesetColumn',
31 'ChangesetIndex',
31 'ChangesetIndex',
32 'ChangesetDefaultClause',
32 'ChangesetDefaultClause',
33 'ColumnDelta',
33 'ColumnDelta',
34 ]
34 ]
35
35
36 def create_column(column, table=None, *p, **kw):
36 def create_column(column, table=None, *p, **kw):
37 """Create a column, given the table.
37 """Create a column, given the table.
38
38
39 API to :meth:`ChangesetColumn.create`.
39 API to :meth:`ChangesetColumn.create`.
40 """
40 """
41 if table is not None:
41 if table is not None:
42 return table.create_column(column, *p, **kw)
42 return table.create_column(column, *p, **kw)
43 return column.create(*p, **kw)
43 return column.create(*p, **kw)
44
44
45
45
46 def drop_column(column, table=None, *p, **kw):
46 def drop_column(column, table=None, *p, **kw):
47 """Drop a column, given the table.
47 """Drop a column, given the table.
48
48
49 API to :meth:`ChangesetColumn.drop`.
49 API to :meth:`ChangesetColumn.drop`.
50 """
50 """
51 if table is not None:
51 if table is not None:
52 return table.drop_column(column, *p, **kw)
52 return table.drop_column(column, *p, **kw)
53 return column.drop(*p, **kw)
53 return column.drop(*p, **kw)
54
54
55
55
56 def rename_table(table, name, engine=None, **kw):
56 def rename_table(table, name, engine=None, **kw):
57 """Rename a table.
57 """Rename a table.
58
58
59 If Table instance is given, engine is not used.
59 If Table instance is given, engine is not used.
60
60
61 API to :meth:`ChangesetTable.rename`.
61 API to :meth:`ChangesetTable.rename`.
62
62
63 :param table: Table to be renamed.
63 :param table: Table to be renamed.
64 :param name: New name for Table.
64 :param name: New name for Table.
65 :param engine: Engine instance.
65 :param engine: Engine instance.
66 :type table: string or Table instance
66 :type table: string or Table instance
67 :type name: string
67 :type name: string
68 :type engine: obj
68 :type engine: obj
69 """
69 """
70 table = _to_table(table, engine)
70 table = _to_table(table, engine)
71 table.rename(name, **kw)
71 table.rename(name, **kw)
72
72
73
73
74 def rename_index(index, name, table=None, engine=None, **kw):
74 def rename_index(index, name, table=None, engine=None, **kw):
75 """Rename an index.
75 """Rename an index.
76
76
77 If Index instance is given,
77 If Index instance is given,
78 table and engine are not used.
78 table and engine are not used.
79
79
80 API to :meth:`ChangesetIndex.rename`.
80 API to :meth:`ChangesetIndex.rename`.
81
81
82 :param index: Index to be renamed.
82 :param index: Index to be renamed.
83 :param name: New name for index.
83 :param name: New name for index.
84 :param table: Table to which Index is reffered.
84 :param table: Table to which Index is reffered.
85 :param engine: Engine instance.
85 :param engine: Engine instance.
86 :type index: string or Index instance
86 :type index: string or Index instance
87 :type name: string
87 :type name: string
88 :type table: string or Table instance
88 :type table: string or Table instance
89 :type engine: obj
89 :type engine: obj
90 """
90 """
91 index = _to_index(index, table, engine)
91 index = _to_index(index, table, engine)
92 index.rename(name, **kw)
92 index.rename(name, **kw)
93
93
94
94
95 def alter_column(*p, **k):
95 def alter_column(*p, **k):
96 """Alter a column.
96 """Alter a column.
97
97
98 This is a helper function that creates a :class:`ColumnDelta` and
98 This is a helper function that creates a :class:`ColumnDelta` and
99 runs it.
99 runs it.
100
100
101 :argument column:
101 :argument column:
102 The name of the column to be altered or a
102 The name of the column to be altered or a
103 :class:`ChangesetColumn` column representing it.
103 :class:`ChangesetColumn` column representing it.
104
104
105 :param table:
105 :param table:
106 A :class:`~sqlalchemy.schema.Table` or table name to
106 A :class:`~sqlalchemy.schema.Table` or table name to
107 for the table where the column will be changed.
107 for the table where the column will be changed.
108
108
109 :param engine:
109 :param engine:
110 The :class:`~sqlalchemy.engine.base.Engine` to use for table
110 The :class:`~sqlalchemy.engine.base.Engine` to use for table
111 reflection and schema alterations.
111 reflection and schema alterations.
112
112
113 :returns: A :class:`ColumnDelta` instance representing the change.
113 :returns: A :class:`ColumnDelta` instance representing the change.
114
114
115
115
116 """
116 """
117
117
118 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
118 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
119 k['table'] = p[0].table
119 k['table'] = p[0].table
120 if 'engine' not in k:
120 if 'engine' not in k:
121 k['engine'] = k['table'].bind
121 k['engine'] = k['table'].bind
122
122
123 # deprecation
123 # deprecation
124 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
124 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
125 warnings.warn(
125 warnings.warn(
126 "Passing a Column object to alter_column is deprecated."
126 "Passing a Column object to alter_column is deprecated."
127 " Just pass in keyword parameters instead.",
127 " Just pass in keyword parameters instead.",
128 MigrateDeprecationWarning
128 MigrateDeprecationWarning
129 )
129 )
130 engine = k['engine']
130 engine = k['engine']
131
131
132 # enough tests seem to break when metadata is always altered
132 # enough tests seem to break when metadata is always altered
133 # that this crutch has to be left in until they can be sorted
133 # that this crutch has to be left in until they can be sorted
134 # out
134 # out
135 k['alter_metadata']=True
135 k['alter_metadata']=True
136
136
137 delta = ColumnDelta(*p, **k)
137 delta = ColumnDelta(*p, **k)
138
138
139 visitorcallable = get_engine_visitor(engine, 'schemachanger')
139 visitorcallable = get_engine_visitor(engine, 'schemachanger')
140 engine._run_visitor(visitorcallable, delta)
140 engine._run_visitor(visitorcallable, delta)
141
141
142 return delta
142 return delta
143
143
144
144
145 def _to_table(table, engine=None):
145 def _to_table(table, engine=None):
146 """Return if instance of Table, else construct new with metadata"""
146 """Return if instance of Table, else construct new with metadata"""
147 if isinstance(table, sqlalchemy.Table):
147 if isinstance(table, sqlalchemy.Table):
148 return table
148 return table
149
149
150 # Given: table name, maybe an engine
150 # Given: table name, maybe an engine
151 meta = sqlalchemy.MetaData()
151 meta = sqlalchemy.MetaData()
152 if engine is not None:
152 if engine is not None:
153 meta.bind = engine
153 meta.bind = engine
154 return sqlalchemy.Table(table, meta)
154 return sqlalchemy.Table(table, meta)
155
155
156
156
157 def _to_index(index, table=None, engine=None):
157 def _to_index(index, table=None, engine=None):
158 """Return if instance of Index, else construct new with metadata"""
158 """Return if instance of Index, else construct new with metadata"""
159 if isinstance(index, sqlalchemy.Index):
159 if isinstance(index, sqlalchemy.Index):
160 return index
160 return index
161
161
162 # Given: index name; table name required
162 # Given: index name; table name required
163 table = _to_table(table, engine)
163 table = _to_table(table, engine)
164 ret = sqlalchemy.Index(index)
164 ret = sqlalchemy.Index(index)
165 ret.table = table
165 ret.table = table
166 return ret
166 return ret
167
167
168
168
169 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
169 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
170 """Extracts the differences between two columns/column-parameters
170 """Extracts the differences between two columns/column-parameters
171
171
172 May receive parameters arranged in several different ways:
172 May receive parameters arranged in several different ways:
173
173
174 * **current_column, new_column, \*p, \*\*kw**
174 * **current_column, new_column, \*p, \*\*kw**
175 Additional parameters can be specified to override column
175 Additional parameters can be specified to override column
176 differences.
176 differences.
177
177
178 * **current_column, \*p, \*\*kw**
178 * **current_column, \*p, \*\*kw**
179 Additional parameters alter current_column. Table name is extracted
179 Additional parameters alter current_column. Table name is extracted
180 from current_column object.
180 from current_column object.
181 Name is changed to current_column.name from current_name,
181 Name is changed to current_column.name from current_name,
182 if current_name is specified.
182 if current_name is specified.
183
183
184 * **current_col_name, \*p, \*\*kw**
184 * **current_col_name, \*p, \*\*kw**
185 Table kw must specified.
185 Table kw must specified.
186
186
187 :param table: Table at which current Column should be bound to.\
187 :param table: Table at which current Column should be bound to.\
188 If table name is given, reflection will be used.
188 If table name is given, reflection will be used.
189 :type table: string or Table instance
189 :type table: string or Table instance
190
190
191 :param metadata: A :class:`MetaData` instance to store
191 :param metadata: A :class:`MetaData` instance to store
192 reflected table names
192 reflected table names
193
193
194 :param engine: When reflecting tables, either engine or metadata must \
194 :param engine: When reflecting tables, either engine or metadata must \
195 be specified to acquire engine object.
195 be specified to acquire engine object.
196 :type engine: :class:`Engine` instance
196 :type engine: :class:`Engine` instance
197 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
197 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
198 `result_column` through :func:`dict` alike object.
198 `result_column` through :func:`dict` alike object.
199
199
200 * :class:`ColumnDelta`.result_column is altered column with new attributes
200 * :class:`ColumnDelta`.result_column is altered column with new attributes
201
201
202 * :class:`ColumnDelta`.current_name is current name of column in db
202 * :class:`ColumnDelta`.current_name is current name of column in db
203
203
204
204
205 """
205 """
206
206
207 # Column attributes that can be altered
207 # Column attributes that can be altered
208 diff_keys = ('name', 'type', 'primary_key', 'nullable',
208 diff_keys = ('name', 'type', 'primary_key', 'nullable',
209 'server_onupdate', 'server_default', 'autoincrement')
209 'server_onupdate', 'server_default', 'autoincrement')
210 diffs = dict()
210 diffs = dict()
211 __visit_name__ = 'column'
211 __visit_name__ = 'column'
212
212
213 def __init__(self, *p, **kw):
213 def __init__(self, *p, **kw):
214 # 'alter_metadata' is not a public api. It exists purely
214 # 'alter_metadata' is not a public api. It exists purely
215 # as a crutch until the tests that fail when 'alter_metadata'
215 # as a crutch until the tests that fail when 'alter_metadata'
216 # behaviour always happens can be sorted out
216 # behaviour always happens can be sorted out
217 self.alter_metadata = kw.pop("alter_metadata", False)
217 self.alter_metadata = kw.pop("alter_metadata", False)
218
218
219 self.meta = kw.pop("metadata", None)
219 self.meta = kw.pop("metadata", None)
220 self.engine = kw.pop("engine", None)
220 self.engine = kw.pop("engine", None)
221
221
222 # Things are initialized differently depending on how many column
222 # Things are initialized differently depending on how many column
223 # parameters are given. Figure out how many and call the appropriate
223 # parameters are given. Figure out how many and call the appropriate
224 # method.
224 # method.
225 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
225 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
226 # At least one column specified
226 # At least one column specified
227 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
227 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
228 # Two columns specified
228 # Two columns specified
229 diffs = self.compare_2_columns(*p, **kw)
229 diffs = self.compare_2_columns(*p, **kw)
230 else:
230 else:
231 # Exactly one column specified
231 # Exactly one column specified
232 diffs = self.compare_1_column(*p, **kw)
232 diffs = self.compare_1_column(*p, **kw)
233 else:
233 else:
234 # Zero columns specified
234 # Zero columns specified
235 if not len(p) or not isinstance(p[0], str):
235 if not len(p) or not isinstance(p[0], str):
236 raise ValueError("First argument must be column name")
236 raise ValueError("First argument must be column name")
237 diffs = self.compare_parameters(*p, **kw)
237 diffs = self.compare_parameters(*p, **kw)
238
238
239 self.apply_diffs(diffs)
239 self.apply_diffs(diffs)
240
240
241 def __repr__(self):
241 def __repr__(self):
242 return '<ColumnDelta altermetadata=%r, %s>' % (
242 return '<ColumnDelta altermetadata=%r, %s>' % (
243 self.alter_metadata,
243 self.alter_metadata,
244 super(ColumnDelta, self).__repr__()
244 super(ColumnDelta, self).__repr__()
245 )
245 )
246
246
247 def __getitem__(self, key):
247 def __getitem__(self, key):
248 if key not in self.keys():
248 if key not in list(self.keys()):
249 raise KeyError("No such diff key, available: %s" % self.diffs )
249 raise KeyError("No such diff key, available: %s" % self.diffs )
250 return getattr(self.result_column, key)
250 return getattr(self.result_column, key)
251
251
252 def __setitem__(self, key, value):
252 def __setitem__(self, key, value):
253 if key not in self.keys():
253 if key not in list(self.keys()):
254 raise KeyError("No such diff key, available: %s" % self.diffs )
254 raise KeyError("No such diff key, available: %s" % self.diffs )
255 setattr(self.result_column, key, value)
255 setattr(self.result_column, key, value)
256
256
257 def __delitem__(self, key):
257 def __delitem__(self, key):
258 raise NotImplementedError
258 raise NotImplementedError
259
259
260 def __len__(self):
260 def __len__(self):
261 raise NotImplementedError
261 raise NotImplementedError
262
262
263 def __iter__(self):
263 def __iter__(self):
264 raise NotImplementedError
264 raise NotImplementedError
265
265
266 def keys(self):
266 def keys(self):
267 return self.diffs.keys()
267 return list(self.diffs.keys())
268
268
269 def compare_parameters(self, current_name, *p, **k):
269 def compare_parameters(self, current_name, *p, **k):
270 """Compares Column objects with reflection"""
270 """Compares Column objects with reflection"""
271 self.table = k.pop('table')
271 self.table = k.pop('table')
272 self.result_column = self._table.c.get(current_name)
272 self.result_column = self._table.c.get(current_name)
273 if len(p):
273 if len(p):
274 k = self._extract_parameters(p, k, self.result_column)
274 k = self._extract_parameters(p, k, self.result_column)
275 return k
275 return k
276
276
277 def compare_1_column(self, col, *p, **k):
277 def compare_1_column(self, col, *p, **k):
278 """Compares one Column object"""
278 """Compares one Column object"""
279 self.table = k.pop('table', None)
279 self.table = k.pop('table', None)
280 if self.table is None:
280 if self.table is None:
281 self.table = col.table
281 self.table = col.table
282 self.result_column = col
282 self.result_column = col
283 if len(p):
283 if len(p):
284 k = self._extract_parameters(p, k, self.result_column)
284 k = self._extract_parameters(p, k, self.result_column)
285 return k
285 return k
286
286
287 def compare_2_columns(self, old_col, new_col, *p, **k):
287 def compare_2_columns(self, old_col, new_col, *p, **k):
288 """Compares two Column objects"""
288 """Compares two Column objects"""
289 self.process_column(new_col)
289 self.process_column(new_col)
290 self.table = k.pop('table', None)
290 self.table = k.pop('table', None)
291 # we cannot use bool() on table in SA06
291 # we cannot use bool() on table in SA06
292 if self.table is None:
292 if self.table is None:
293 self.table = old_col.table
293 self.table = old_col.table
294 if self.table is None:
294 if self.table is None:
295 new_col.table
295 new_col.table
296 self.result_column = old_col
296 self.result_column = old_col
297
297
298 # set differences
298 # set differences
299 # leave out some stuff for later comp
299 # leave out some stuff for later comp
300 for key in (set(self.diff_keys) - set(('type',))):
300 for key in (set(self.diff_keys) - set(('type',))):
301 val = getattr(new_col, key, None)
301 val = getattr(new_col, key, None)
302 if getattr(self.result_column, key, None) != val:
302 if getattr(self.result_column, key, None) != val:
303 k.setdefault(key, val)
303 k.setdefault(key, val)
304
304
305 # inspect types
305 # inspect types
306 if not self.are_column_types_eq(self.result_column.type, new_col.type):
306 if not self.are_column_types_eq(self.result_column.type, new_col.type):
307 k.setdefault('type', new_col.type)
307 k.setdefault('type', new_col.type)
308
308
309 if len(p):
309 if len(p):
310 k = self._extract_parameters(p, k, self.result_column)
310 k = self._extract_parameters(p, k, self.result_column)
311 return k
311 return k
312
312
313 def apply_diffs(self, diffs):
313 def apply_diffs(self, diffs):
314 """Populate dict and column object with new values"""
314 """Populate dict and column object with new values"""
315 self.diffs = diffs
315 self.diffs = diffs
316 for key in self.diff_keys:
316 for key in self.diff_keys:
317 if key in diffs:
317 if key in diffs:
318 setattr(self.result_column, key, diffs[key])
318 setattr(self.result_column, key, diffs[key])
319
319
320 self.process_column(self.result_column)
320 self.process_column(self.result_column)
321
321
322 # create an instance of class type if not yet
322 # create an instance of class type if not yet
323 if 'type' in diffs and callable(self.result_column.type):
323 if 'type' in diffs and callable(self.result_column.type):
324 self.result_column.type = self.result_column.type()
324 self.result_column.type = self.result_column.type()
325
325
326 # add column to the table
326 # add column to the table
327 if self.table is not None and self.alter_metadata:
327 if self.table is not None and self.alter_metadata:
328 self.result_column.add_to_table(self.table)
328 self.result_column.add_to_table(self.table)
329
329
330 def are_column_types_eq(self, old_type, new_type):
330 def are_column_types_eq(self, old_type, new_type):
331 """Compares two types to be equal"""
331 """Compares two types to be equal"""
332 ret = old_type.__class__ == new_type.__class__
332 ret = old_type.__class__ == new_type.__class__
333
333
334 # String length is a special case
334 # String length is a special case
335 if ret and isinstance(new_type, sqlalchemy.types.String):
335 if ret and isinstance(new_type, sqlalchemy.types.String):
336 ret = (getattr(old_type, 'length', None) == \
336 ret = (getattr(old_type, 'length', None) == \
337 getattr(new_type, 'length', None))
337 getattr(new_type, 'length', None))
338 return ret
338 return ret
339
339
340 def _extract_parameters(self, p, k, column):
340 def _extract_parameters(self, p, k, column):
341 """Extracts data from p and modifies diffs"""
341 """Extracts data from p and modifies diffs"""
342 p = list(p)
342 p = list(p)
343 while len(p):
343 while len(p):
344 if isinstance(p[0], str):
344 if isinstance(p[0], str):
345 k.setdefault('name', p.pop(0))
345 k.setdefault('name', p.pop(0))
346 elif isinstance(p[0], sqlalchemy.types.TypeEngine):
346 elif isinstance(p[0], sqlalchemy.types.TypeEngine):
347 k.setdefault('type', p.pop(0))
347 k.setdefault('type', p.pop(0))
348 elif callable(p[0]):
348 elif callable(p[0]):
349 p[0] = p[0]()
349 p[0] = p[0]()
350 else:
350 else:
351 break
351 break
352
352
353 if len(p):
353 if len(p):
354 new_col = column.copy_fixed()
354 new_col = column.copy_fixed()
355 new_col._init_items(*p)
355 new_col._init_items(*p)
356 k = self.compare_2_columns(column, new_col, **k)
356 k = self.compare_2_columns(column, new_col, **k)
357 return k
357 return k
358
358
359 def process_column(self, column):
359 def process_column(self, column):
360 """Processes default values for column"""
360 """Processes default values for column"""
361 # XXX: this is a snippet from SA processing of positional parameters
361 # XXX: this is a snippet from SA processing of positional parameters
362 toinit = list()
362 toinit = list()
363
363
364 if column.server_default is not None:
364 if column.server_default is not None:
365 if isinstance(column.server_default, sqlalchemy.FetchedValue):
365 if isinstance(column.server_default, sqlalchemy.FetchedValue):
366 toinit.append(column.server_default)
366 toinit.append(column.server_default)
367 else:
367 else:
368 toinit.append(sqlalchemy.DefaultClause(column.server_default))
368 toinit.append(sqlalchemy.DefaultClause(column.server_default))
369 if column.server_onupdate is not None:
369 if column.server_onupdate is not None:
370 if isinstance(column.server_onupdate, FetchedValue):
370 if isinstance(column.server_onupdate, FetchedValue):
371 toinit.append(column.server_default)
371 toinit.append(column.server_default)
372 else:
372 else:
373 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
373 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
374 for_update=True))
374 for_update=True))
375 if toinit:
375 if toinit:
376 column._init_items(*toinit)
376 column._init_items(*toinit)
377
377
378 def _get_table(self):
378 def _get_table(self):
379 return getattr(self, '_table', None)
379 return getattr(self, '_table', None)
380
380
381 def _set_table(self, table):
381 def _set_table(self, table):
382 if isinstance(table, str):
382 if isinstance(table, str):
383 if self.alter_metadata:
383 if self.alter_metadata:
384 if not self.meta:
384 if not self.meta:
385 raise ValueError("metadata must be specified for table"
385 raise ValueError("metadata must be specified for table"
386 " reflection when using alter_metadata")
386 " reflection when using alter_metadata")
387 meta = self.meta
387 meta = self.meta
388 if self.engine:
388 if self.engine:
389 meta.bind = self.engine
389 meta.bind = self.engine
390 else:
390 else:
391 if not self.engine and not self.meta:
391 if not self.engine and not self.meta:
392 raise ValueError("engine or metadata must be specified"
392 raise ValueError("engine or metadata must be specified"
393 " to reflect tables")
393 " to reflect tables")
394 if not self.engine:
394 if not self.engine:
395 self.engine = self.meta.bind
395 self.engine = self.meta.bind
396 meta = sqlalchemy.MetaData(bind=self.engine)
396 meta = sqlalchemy.MetaData(bind=self.engine)
397 self._table = sqlalchemy.Table(table, meta, autoload=True)
397 self._table = sqlalchemy.Table(table, meta, autoload=True)
398 elif isinstance(table, sqlalchemy.Table):
398 elif isinstance(table, sqlalchemy.Table):
399 self._table = table
399 self._table = table
400 if not self.alter_metadata:
400 if not self.alter_metadata:
401 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
401 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
402 def _get_result_column(self):
402 def _get_result_column(self):
403 return getattr(self, '_result_column', None)
403 return getattr(self, '_result_column', None)
404
404
405 def _set_result_column(self, column):
405 def _set_result_column(self, column):
406 """Set Column to Table based on alter_metadata evaluation."""
406 """Set Column to Table based on alter_metadata evaluation."""
407 self.process_column(column)
407 self.process_column(column)
408 if not hasattr(self, 'current_name'):
408 if not hasattr(self, 'current_name'):
409 self.current_name = column.name
409 self.current_name = column.name
410 if self.alter_metadata:
410 if self.alter_metadata:
411 self._result_column = column
411 self._result_column = column
412 else:
412 else:
413 self._result_column = column.copy_fixed()
413 self._result_column = column.copy_fixed()
414
414
415 table = property(_get_table, _set_table)
415 table = property(_get_table, _set_table)
416 result_column = property(_get_result_column, _set_result_column)
416 result_column = property(_get_result_column, _set_result_column)
417
417
418
418
419 class ChangesetTable(object):
419 class ChangesetTable(object):
420 """Changeset extensions to SQLAlchemy tables."""
420 """Changeset extensions to SQLAlchemy tables."""
421
421
422 def create_column(self, column, *p, **kw):
422 def create_column(self, column, *p, **kw):
423 """Creates a column.
423 """Creates a column.
424
424
425 The column parameter may be a column definition or the name of
425 The column parameter may be a column definition or the name of
426 a column in this table.
426 a column in this table.
427
427
428 API to :meth:`ChangesetColumn.create`
428 API to :meth:`ChangesetColumn.create`
429
429
430 :param column: Column to be created
430 :param column: Column to be created
431 :type column: Column instance or string
431 :type column: Column instance or string
432 """
432 """
433 if not isinstance(column, sqlalchemy.Column):
433 if not isinstance(column, sqlalchemy.Column):
434 # It's a column name
434 # It's a column name
435 column = getattr(self.c, str(column))
435 column = getattr(self.c, str(column))
436 column.create(table=self, *p, **kw)
436 column.create(table=self, *p, **kw)
437
437
438 def drop_column(self, column, *p, **kw):
438 def drop_column(self, column, *p, **kw):
439 """Drop a column, given its name or definition.
439 """Drop a column, given its name or definition.
440
440
441 API to :meth:`ChangesetColumn.drop`
441 API to :meth:`ChangesetColumn.drop`
442
442
443 :param column: Column to be droped
443 :param column: Column to be droped
444 :type column: Column instance or string
444 :type column: Column instance or string
445 """
445 """
446 if not isinstance(column, sqlalchemy.Column):
446 if not isinstance(column, sqlalchemy.Column):
447 # It's a column name
447 # It's a column name
448 try:
448 try:
449 column = getattr(self.c, str(column))
449 column = getattr(self.c, str(column))
450 except AttributeError:
450 except AttributeError:
451 # That column isn't part of the table. We don't need
451 # That column isn't part of the table. We don't need
452 # its entire definition to drop the column, just its
452 # its entire definition to drop the column, just its
453 # name, so create a dummy column with the same name.
453 # name, so create a dummy column with the same name.
454 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
454 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
455 column.drop(table=self, *p, **kw)
455 column.drop(table=self, *p, **kw)
456
456
457 def rename(self, name, connection=None, **kwargs):
457 def rename(self, name, connection=None, **kwargs):
458 """Rename this table.
458 """Rename this table.
459
459
460 :param name: New name of the table.
460 :param name: New name of the table.
461 :type name: string
461 :type name: string
462 :param connection: reuse connection istead of creating new one.
462 :param connection: reuse connection istead of creating new one.
463 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
463 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
464 """
464 """
465 engine = self.bind
465 engine = self.bind
466 self.new_name = name
466 self.new_name = name
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469
469
470 # Fix metadata registration
470 # Fix metadata registration
471 self.name = name
471 self.name = name
472 self.deregister()
472 self.deregister()
473 self._set_parent(self.metadata)
473 self._set_parent(self.metadata)
474
474
475 def _meta_key(self):
475 def _meta_key(self):
476 """Get the meta key for this table."""
476 """Get the meta key for this table."""
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478
478
479 def deregister(self):
479 def deregister(self):
480 """Remove this table from its metadata"""
480 """Remove this table from its metadata"""
481 if SQLA_07:
481 if SQLA_07:
482 self.metadata._remove_table(self.name, self.schema)
482 self.metadata._remove_table(self.name, self.schema)
483 else:
483 else:
484 key = self._meta_key()
484 key = self._meta_key()
485 meta = self.metadata
485 meta = self.metadata
486 if key in meta.tables:
486 if key in meta.tables:
487 del meta.tables[key]
487 del meta.tables[key]
488
488
489
489
490 class ChangesetColumn(object):
490 class ChangesetColumn(object):
491 """Changeset extensions to SQLAlchemy columns."""
491 """Changeset extensions to SQLAlchemy columns."""
492
492
493 def alter(self, *p, **k):
493 def alter(self, *p, **k):
494 """Makes a call to :func:`alter_column` for the column this
494 """Makes a call to :func:`alter_column` for the column this
495 method is called on.
495 method is called on.
496 """
496 """
497 if 'table' not in k:
497 if 'table' not in k:
498 k['table'] = self.table
498 k['table'] = self.table
499 if 'engine' not in k:
499 if 'engine' not in k:
500 k['engine'] = k['table'].bind
500 k['engine'] = k['table'].bind
501 return alter_column(self, *p, **k)
501 return alter_column(self, *p, **k)
502
502
503 def create(self, table=None, index_name=None, unique_name=None,
503 def create(self, table=None, index_name=None, unique_name=None,
504 primary_key_name=None, populate_default=True, connection=None, **kwargs):
504 primary_key_name=None, populate_default=True, connection=None, **kwargs):
505 """Create this column in the database.
505 """Create this column in the database.
506
506
507 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
507 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
508 for most databases.
508 for most databases.
509
509
510 :param table: Table instance to create on.
510 :param table: Table instance to create on.
511 :param index_name: Creates :class:`ChangesetIndex` on this column.
511 :param index_name: Creates :class:`ChangesetIndex` on this column.
512 :param unique_name: Creates :class:\
512 :param unique_name: Creates :class:\
513 `~migrate.changeset.constraint.UniqueConstraint` on this column.
513 `~migrate.changeset.constraint.UniqueConstraint` on this column.
514 :param primary_key_name: Creates :class:\
514 :param primary_key_name: Creates :class:\
515 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
515 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
516 :param populate_default: If True, created column will be \
516 :param populate_default: If True, created column will be \
517 populated with defaults
517 populated with defaults
518 :param connection: reuse connection istead of creating new one.
518 :param connection: reuse connection istead of creating new one.
519 :type table: Table instance
519 :type table: Table instance
520 :type index_name: string
520 :type index_name: string
521 :type unique_name: string
521 :type unique_name: string
522 :type primary_key_name: string
522 :type primary_key_name: string
523 :type populate_default: bool
523 :type populate_default: bool
524 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
524 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
525
525
526 :returns: self
526 :returns: self
527 """
527 """
528 self.populate_default = populate_default
528 self.populate_default = populate_default
529 self.index_name = index_name
529 self.index_name = index_name
530 self.unique_name = unique_name
530 self.unique_name = unique_name
531 self.primary_key_name = primary_key_name
531 self.primary_key_name = primary_key_name
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 self._check_sanity_constraints(cons)
533 self._check_sanity_constraints(cons)
534
534
535 self.add_to_table(table)
535 self.add_to_table(table)
536 engine = self.table.bind
536 engine = self.table.bind
537 visitorcallable = get_engine_visitor(engine, 'columngenerator')
537 visitorcallable = get_engine_visitor(engine, 'columngenerator')
538 engine._run_visitor(visitorcallable, self, connection, **kwargs)
538 engine._run_visitor(visitorcallable, self, connection, **kwargs)
539
539
540 # TODO: reuse existing connection
540 # TODO: reuse existing connection
541 if self.populate_default and self.default is not None:
541 if self.populate_default and self.default is not None:
542 stmt = table.update().values({self: engine._execute_default(self.default)})
542 stmt = table.update().values({self: engine._execute_default(self.default)})
543 engine.execute(stmt)
543 engine.execute(stmt)
544
544
545 return self
545 return self
546
546
547 def drop(self, table=None, connection=None, **kwargs):
547 def drop(self, table=None, connection=None, **kwargs):
548 """Drop this column from the database, leaving its table intact.
548 """Drop this column from the database, leaving its table intact.
549
549
550 ``ALTER TABLE DROP COLUMN``, for most databases.
550 ``ALTER TABLE DROP COLUMN``, for most databases.
551
551
552 :param connection: reuse connection istead of creating new one.
552 :param connection: reuse connection istead of creating new one.
553 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
553 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
554 """
554 """
555 if table is not None:
555 if table is not None:
556 self.table = table
556 self.table = table
557 engine = self.table.bind
557 engine = self.table.bind
558 visitorcallable = get_engine_visitor(engine, 'columndropper')
558 visitorcallable = get_engine_visitor(engine, 'columndropper')
559 engine._run_visitor(visitorcallable, self, connection, **kwargs)
559 engine._run_visitor(visitorcallable, self, connection, **kwargs)
560 self.remove_from_table(self.table, unset_table=False)
560 self.remove_from_table(self.table, unset_table=False)
561 self.table = None
561 self.table = None
562 return self
562 return self
563
563
564 def add_to_table(self, table):
564 def add_to_table(self, table):
565 if table is not None and self.table is None:
565 if table is not None and self.table is None:
566 if SQLA_07:
566 if SQLA_07:
567 table.append_column(self)
567 table.append_column(self)
568 else:
568 else:
569 self._set_parent(table)
569 self._set_parent(table)
570
570
571 def _col_name_in_constraint(self,cons,name):
571 def _col_name_in_constraint(self,cons,name):
572 return False
572 return False
573
573
574 def remove_from_table(self, table, unset_table=True):
574 def remove_from_table(self, table, unset_table=True):
575 # TODO: remove primary keys, constraints, etc
575 # TODO: remove primary keys, constraints, etc
576 if unset_table:
576 if unset_table:
577 self.table = None
577 self.table = None
578
578
579 to_drop = set()
579 to_drop = set()
580 for index in table.indexes:
580 for index in table.indexes:
581 columns = []
581 columns = []
582 for col in index.columns:
582 for col in index.columns:
583 if col.name!=self.name:
583 if col.name!=self.name:
584 columns.append(col)
584 columns.append(col)
585 if columns:
585 if columns:
586 index.columns = columns
586 index.columns = columns
587 if SQLA_08:
587 if SQLA_08:
588 index.expressions = columns
588 index.expressions = columns
589 else:
589 else:
590 to_drop.add(index)
590 to_drop.add(index)
591 table.indexes = table.indexes - to_drop
591 table.indexes = table.indexes - to_drop
592
592
593 to_drop = set()
593 to_drop = set()
594 for cons in table.constraints:
594 for cons in table.constraints:
595 # TODO: deal with other types of constraint
595 # TODO: deal with other types of constraint
596 if isinstance(cons,(ForeignKeyConstraint,
596 if isinstance(cons,(ForeignKeyConstraint,
597 UniqueConstraint)):
597 UniqueConstraint)):
598 for col_name in cons.columns:
598 for col_name in cons.columns:
599 if not isinstance(col_name, str):
599 if not isinstance(col_name, str):
600 col_name = col_name.name
600 col_name = col_name.name
601 if self.name==col_name:
601 if self.name==col_name:
602 to_drop.add(cons)
602 to_drop.add(cons)
603 table.constraints = table.constraints - to_drop
603 table.constraints = table.constraints - to_drop
604
604
605 if table.c.contains_column(self):
605 if table.c.contains_column(self):
606 if SQLA_07:
606 if SQLA_07:
607 table._columns.remove(self)
607 table._columns.remove(self)
608 else:
608 else:
609 table.c.remove(self)
609 table.c.remove(self)
610
610
611 # TODO: this is fixed in 0.6
611 # TODO: this is fixed in 0.6
612 def copy_fixed(self, **kw):
612 def copy_fixed(self, **kw):
613 """Create a copy of this ``Column``, with all attributes."""
613 """Create a copy of this ``Column``, with all attributes."""
614 q = util.safe_quote(self)
614 q = util.safe_quote(self)
615 return sqlalchemy.Column(self.name, self.type, self.default,
615 return sqlalchemy.Column(self.name, self.type, self.default,
616 key=self.key,
616 key=self.key,
617 primary_key=self.primary_key,
617 primary_key=self.primary_key,
618 nullable=self.nullable,
618 nullable=self.nullable,
619 quote=q,
619 quote=q,
620 index=self.index,
620 index=self.index,
621 unique=self.unique,
621 unique=self.unique,
622 onupdate=self.onupdate,
622 onupdate=self.onupdate,
623 autoincrement=self.autoincrement,
623 autoincrement=self.autoincrement,
624 server_default=self.server_default,
624 server_default=self.server_default,
625 server_onupdate=self.server_onupdate,
625 server_onupdate=self.server_onupdate,
626 *[c.copy(**kw) for c in self.constraints])
626 *[c.copy(**kw) for c in self.constraints])
627
627
628 def _check_sanity_constraints(self, name):
628 def _check_sanity_constraints(self, name):
629 """Check if constraints names are correct"""
629 """Check if constraints names are correct"""
630 obj = getattr(self, name)
630 obj = getattr(self, name)
631 if (getattr(self, name[:-5]) and not obj):
631 if (getattr(self, name[:-5]) and not obj):
632 raise InvalidConstraintError("Column.create() accepts index_name,"
632 raise InvalidConstraintError("Column.create() accepts index_name,"
633 " primary_key_name and unique_name to generate constraints")
633 " primary_key_name and unique_name to generate constraints")
634 if not isinstance(obj, str) and obj is not None:
634 if not isinstance(obj, str) and obj is not None:
635 raise InvalidConstraintError(
635 raise InvalidConstraintError(
636 "%s argument for column must be constraint name" % name)
636 "%s argument for column must be constraint name" % name)
637
637
638
638
639 class ChangesetIndex(object):
639 class ChangesetIndex(object):
640 """Changeset extensions to SQLAlchemy Indexes."""
640 """Changeset extensions to SQLAlchemy Indexes."""
641
641
642 __visit_name__ = 'index'
642 __visit_name__ = 'index'
643
643
644 def rename(self, name, connection=None, **kwargs):
644 def rename(self, name, connection=None, **kwargs):
645 """Change the name of an index.
645 """Change the name of an index.
646
646
647 :param name: New name of the Index.
647 :param name: New name of the Index.
648 :type name: string
648 :type name: string
649 :param connection: reuse connection istead of creating new one.
649 :param connection: reuse connection istead of creating new one.
650 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
650 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
651 """
651 """
652 engine = self.table.bind
652 engine = self.table.bind
653 self.new_name = name
653 self.new_name = name
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 self.name = name
656 self.name = name
657
657
658
658
659 class ChangesetDefaultClause(object):
659 class ChangesetDefaultClause(object):
660 """Implements comparison between :class:`DefaultClause` instances"""
660 """Implements comparison between :class:`DefaultClause` instances"""
661
661
662 def __eq__(self, other):
662 def __eq__(self, other):
663 if isinstance(other, self.__class__):
663 if isinstance(other, self.__class__):
664 if self.arg == other.arg:
664 if self.arg == other.arg:
665 return True
665 return True
666
666
667 def __ne__(self, other):
667 def __ne__(self, other):
668 return not self.__eq__(other)
668 return not self.__eq__(other)
@@ -1,386 +1,386 b''
1 """
1 """
2 This module provides an external API to the versioning system.
2 This module provides an external API to the versioning system.
3
3
4 .. versionchanged:: 0.6.0
4 .. versionchanged:: 0.6.0
5 :func:`migrate.versioning.api.test` and schema diff functions
5 :func:`migrate.versioning.api.test` and schema diff functions
6 changed order of positional arguments so all accept `url` and `repository`
6 changed order of positional arguments so all accept `url` and `repository`
7 as first arguments.
7 as first arguments.
8
8
9 .. versionchanged:: 0.5.4
9 .. versionchanged:: 0.5.4
10 ``--preview_sql`` displays source file when using SQL scripts.
10 ``--preview_sql`` displays source file when using SQL scripts.
11 If Python script is used, it runs the action with mocked engine and
11 If Python script is used, it runs the action with mocked engine and
12 returns captured SQL statements.
12 returns captured SQL statements.
13
13
14 .. versionchanged:: 0.5.4
14 .. versionchanged:: 0.5.4
15 Deprecated ``--echo`` parameter in favour of new
15 Deprecated ``--echo`` parameter in favour of new
16 :func:`migrate.versioning.util.construct_engine` behavior.
16 :func:`migrate.versioning.util.construct_engine` behavior.
17 """
17 """
18
18
19 # Dear migrate developers,
19 # Dear migrate developers,
20 #
20 #
21 # please do not comment this module using sphinx syntax because its
21 # please do not comment this module using sphinx syntax because its
22 # docstrings are presented as user help and most users cannot
22 # docstrings are presented as user help and most users cannot
23 # interpret sphinx annotated ReStructuredText.
23 # interpret sphinx annotated ReStructuredText.
24 #
24 #
25 # Thanks,
25 # Thanks,
26 # Jan Dittberner
26 # Jan Dittberner
27
27
28 import sys
28 import sys
29 import inspect
29 import inspect
30 import logging
30 import logging
31
31
32 from rhodecode.lib.dbmigrate.migrate import exceptions
32 from rhodecode.lib.dbmigrate.migrate import exceptions
33 from rhodecode.lib.dbmigrate.migrate.versioning import (
33 from rhodecode.lib.dbmigrate.migrate.versioning import (
34 repository, schema, version,
34 repository, schema, version,
35 script as script_ # command name conflict
35 script as script_ # command name conflict
36 )
36 )
37 from rhodecode.lib.dbmigrate.migrate.versioning.util import (
37 from rhodecode.lib.dbmigrate.migrate.versioning.util import (
38 catch_known_errors, with_engine)
38 catch_known_errors, with_engine)
39
39
40
40
41 log = logging.getLogger(__name__)
41 log = logging.getLogger(__name__)
42 command_desc = {
42 command_desc = {
43 'help': 'displays help on a given command',
43 'help': 'displays help on a given command',
44 'create': 'create an empty repository at the specified path',
44 'create': 'create an empty repository at the specified path',
45 'script': 'create an empty change Python script',
45 'script': 'create an empty change Python script',
46 'script_sql': 'create empty change SQL scripts for given database',
46 'script_sql': 'create empty change SQL scripts for given database',
47 'version': 'display the latest version available in a repository',
47 'version': 'display the latest version available in a repository',
48 'db_version': 'show the current version of the repository under version control',
48 'db_version': 'show the current version of the repository under version control',
49 'source': 'display the Python code for a particular version in this repository',
49 'source': 'display the Python code for a particular version in this repository',
50 'version_control': 'mark a database as under this repository\'s version control',
50 'version_control': 'mark a database as under this repository\'s version control',
51 'upgrade': 'upgrade a database to a later version',
51 'upgrade': 'upgrade a database to a later version',
52 'downgrade': 'downgrade a database to an earlier version',
52 'downgrade': 'downgrade a database to an earlier version',
53 'drop_version_control': 'removes version control from a database',
53 'drop_version_control': 'removes version control from a database',
54 'manage': 'creates a Python script that runs Migrate with a set of default values',
54 'manage': 'creates a Python script that runs Migrate with a set of default values',
55 'test': 'performs the upgrade and downgrade command on the given database',
55 'test': 'performs the upgrade and downgrade command on the given database',
56 'compare_model_to_db': 'compare MetaData against the current database state',
56 'compare_model_to_db': 'compare MetaData against the current database state',
57 'create_model': 'dump the current database as a Python model to stdout',
57 'create_model': 'dump the current database as a Python model to stdout',
58 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
58 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
59 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
59 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
60 }
60 }
61 __all__ = command_desc.keys()
61 __all__ = list(command_desc.keys())
62
62
63 Repository = repository.Repository
63 Repository = repository.Repository
64 ControlledSchema = schema.ControlledSchema
64 ControlledSchema = schema.ControlledSchema
65 VerNum = version.VerNum
65 VerNum = version.VerNum
66 PythonScript = script_.PythonScript
66 PythonScript = script_.PythonScript
67 SqlScript = script_.SqlScript
67 SqlScript = script_.SqlScript
68
68
69
69
70 # deprecated
70 # deprecated
71 def help(cmd=None, **opts):
71 def help(cmd=None, **opts):
72 """%prog help COMMAND
72 """%prog help COMMAND
73
73
74 Displays help on a given command.
74 Displays help on a given command.
75 """
75 """
76 if cmd is None:
76 if cmd is None:
77 raise exceptions.UsageError(None)
77 raise exceptions.UsageError(None)
78 try:
78 try:
79 func = globals()[cmd]
79 func = globals()[cmd]
80 except:
80 except:
81 raise exceptions.UsageError(
81 raise exceptions.UsageError(
82 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
82 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
83 ret = func.__doc__
83 ret = func.__doc__
84 if sys.argv[0]:
84 if sys.argv[0]:
85 ret = ret.replace('%prog', sys.argv[0])
85 ret = ret.replace('%prog', sys.argv[0])
86 return ret
86 return ret
87
87
88 @catch_known_errors
88 @catch_known_errors
89 def create(repository, name, **opts):
89 def create(repository, name, **opts):
90 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
90 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
91
91
92 Create an empty repository at the specified path.
92 Create an empty repository at the specified path.
93
93
94 You can specify the version_table to be used; by default, it is
94 You can specify the version_table to be used; by default, it is
95 'migrate_version'. This table is created in all version-controlled
95 'migrate_version'. This table is created in all version-controlled
96 databases.
96 databases.
97 """
97 """
98 repo_path = Repository.create(repository, name, **opts)
98 repo_path = Repository.create(repository, name, **opts)
99
99
100
100
101 @catch_known_errors
101 @catch_known_errors
102 def script(description, repository, **opts):
102 def script(description, repository, **opts):
103 """%prog script DESCRIPTION REPOSITORY_PATH
103 """%prog script DESCRIPTION REPOSITORY_PATH
104
104
105 Create an empty change script using the next unused version number
105 Create an empty change script using the next unused version number
106 appended with the given description.
106 appended with the given description.
107
107
108 For instance, manage.py script "Add initial tables" creates:
108 For instance, manage.py script "Add initial tables" creates:
109 repository/versions/001_Add_initial_tables.py
109 repository/versions/001_Add_initial_tables.py
110 """
110 """
111 repo = Repository(repository)
111 repo = Repository(repository)
112 repo.create_script(description, **opts)
112 repo.create_script(description, **opts)
113
113
114
114
115 @catch_known_errors
115 @catch_known_errors
116 def script_sql(database, description, repository, **opts):
116 def script_sql(database, description, repository, **opts):
117 """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
117 """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
118
118
119 Create empty change SQL scripts for given DATABASE, where DATABASE
119 Create empty change SQL scripts for given DATABASE, where DATABASE
120 is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
120 is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
121 or generic ('default').
121 or generic ('default').
122
122
123 For instance, manage.py script_sql postgresql description creates:
123 For instance, manage.py script_sql postgresql description creates:
124 repository/versions/001_description_postgresql_upgrade.sql and
124 repository/versions/001_description_postgresql_upgrade.sql and
125 repository/versions/001_description_postgresql_downgrade.sql
125 repository/versions/001_description_postgresql_downgrade.sql
126 """
126 """
127 repo = Repository(repository)
127 repo = Repository(repository)
128 repo.create_script_sql(database, description, **opts)
128 repo.create_script_sql(database, description, **opts)
129
129
130
130
131 def version(repository, **opts):
131 def version(repository, **opts):
132 """%prog version REPOSITORY_PATH
132 """%prog version REPOSITORY_PATH
133
133
134 Display the latest version available in a repository.
134 Display the latest version available in a repository.
135 """
135 """
136 repo = Repository(repository)
136 repo = Repository(repository)
137 return repo.latest
137 return repo.latest
138
138
139
139
140 @with_engine
140 @with_engine
141 def db_version(url, repository, **opts):
141 def db_version(url, repository, **opts):
142 """%prog db_version URL REPOSITORY_PATH
142 """%prog db_version URL REPOSITORY_PATH
143
143
144 Show the current version of the repository with the given
144 Show the current version of the repository with the given
145 connection string, under version control of the specified
145 connection string, under version control of the specified
146 repository.
146 repository.
147
147
148 The url should be any valid SQLAlchemy connection string.
148 The url should be any valid SQLAlchemy connection string.
149 """
149 """
150 engine = opts.pop('engine')
150 engine = opts.pop('engine')
151 schema = ControlledSchema(engine, repository)
151 schema = ControlledSchema(engine, repository)
152 return schema.version
152 return schema.version
153
153
154
154
155 def source(version, dest=None, repository=None, **opts):
155 def source(version, dest=None, repository=None, **opts):
156 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
156 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
157
157
158 Display the Python code for a particular version in this
158 Display the Python code for a particular version in this
159 repository. Save it to the file at DESTINATION or, if omitted,
159 repository. Save it to the file at DESTINATION or, if omitted,
160 send to stdout.
160 send to stdout.
161 """
161 """
162 if repository is None:
162 if repository is None:
163 raise exceptions.UsageError("A repository must be specified")
163 raise exceptions.UsageError("A repository must be specified")
164 repo = Repository(repository)
164 repo = Repository(repository)
165 ret = repo.version(version).script().source()
165 ret = repo.version(version).script().source()
166 if dest is not None:
166 if dest is not None:
167 with open(dest, 'w') as f:
167 with open(dest, 'w') as f:
168 f.write(ret)
168 f.write(ret)
169 ret = None
169 ret = None
170 return ret
170 return ret
171
171
172
172
173 def upgrade(url, repository, version=None, **opts):
173 def upgrade(url, repository, version=None, **opts):
174 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
174 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
175
175
176 Upgrade a database to a later version.
176 Upgrade a database to a later version.
177
177
178 This runs the upgrade() function defined in your change scripts.
178 This runs the upgrade() function defined in your change scripts.
179
179
180 By default, the database is updated to the latest available
180 By default, the database is updated to the latest available
181 version. You may specify a version instead, if you wish.
181 version. You may specify a version instead, if you wish.
182
182
183 You may preview the Python or SQL code to be executed, rather than
183 You may preview the Python or SQL code to be executed, rather than
184 actually executing it, using the appropriate 'preview' option.
184 actually executing it, using the appropriate 'preview' option.
185 """
185 """
186 err = "Cannot upgrade a database of version %s to version %s. "\
186 err = "Cannot upgrade a database of version %s to version %s. "\
187 "Try 'downgrade' instead."
187 "Try 'downgrade' instead."
188 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
188 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
189
189
190
190
191 def downgrade(url, repository, version, **opts):
191 def downgrade(url, repository, version, **opts):
192 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
192 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
193
193
194 Downgrade a database to an earlier version.
194 Downgrade a database to an earlier version.
195
195
196 This is the reverse of upgrade; this runs the downgrade() function
196 This is the reverse of upgrade; this runs the downgrade() function
197 defined in your change scripts.
197 defined in your change scripts.
198
198
199 You may preview the Python or SQL code to be executed, rather than
199 You may preview the Python or SQL code to be executed, rather than
200 actually executing it, using the appropriate 'preview' option.
200 actually executing it, using the appropriate 'preview' option.
201 """
201 """
202 err = "Cannot downgrade a database of version %s to version %s. "\
202 err = "Cannot downgrade a database of version %s to version %s. "\
203 "Try 'upgrade' instead."
203 "Try 'upgrade' instead."
204 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
204 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
205
205
206 @with_engine
206 @with_engine
207 def test(url, repository, **opts):
207 def test(url, repository, **opts):
208 """%prog test URL REPOSITORY_PATH [VERSION]
208 """%prog test URL REPOSITORY_PATH [VERSION]
209
209
210 Performs the upgrade and downgrade option on the given
210 Performs the upgrade and downgrade option on the given
211 database. This is not a real test and may leave the database in a
211 database. This is not a real test and may leave the database in a
212 bad state. You should therefore better run the test on a copy of
212 bad state. You should therefore better run the test on a copy of
213 your database.
213 your database.
214 """
214 """
215 engine = opts.pop('engine')
215 engine = opts.pop('engine')
216 repos = Repository(repository)
216 repos = Repository(repository)
217
217
218 # Upgrade
218 # Upgrade
219 log.info("Upgrading...")
219 log.info("Upgrading...")
220 script = repos.version(None).script(engine.name, 'upgrade')
220 script = repos.version(None).script(engine.name, 'upgrade')
221 script.run(engine, 1)
221 script.run(engine, 1)
222 log.info("done")
222 log.info("done")
223
223
224 log.info("Downgrading...")
224 log.info("Downgrading...")
225 script = repos.version(None).script(engine.name, 'downgrade')
225 script = repos.version(None).script(engine.name, 'downgrade')
226 script.run(engine, -1)
226 script.run(engine, -1)
227 log.info("done")
227 log.info("done")
228 log.info("Success")
228 log.info("Success")
229
229
230
230
231 @with_engine
231 @with_engine
232 def version_control(url, repository, version=None, **opts):
232 def version_control(url, repository, version=None, **opts):
233 """%prog version_control URL REPOSITORY_PATH [VERSION]
233 """%prog version_control URL REPOSITORY_PATH [VERSION]
234
234
235 Mark a database as under this repository's version control.
235 Mark a database as under this repository's version control.
236
236
237 Once a database is under version control, schema changes should
237 Once a database is under version control, schema changes should
238 only be done via change scripts in this repository.
238 only be done via change scripts in this repository.
239
239
240 This creates the table version_table in the database.
240 This creates the table version_table in the database.
241
241
242 The url should be any valid SQLAlchemy connection string.
242 The url should be any valid SQLAlchemy connection string.
243
243
244 By default, the database begins at version 0 and is assumed to be
244 By default, the database begins at version 0 and is assumed to be
245 empty. If the database is not empty, you may specify a version at
245 empty. If the database is not empty, you may specify a version at
246 which to begin instead. No attempt is made to verify this
246 which to begin instead. No attempt is made to verify this
247 version's correctness - the database schema is expected to be
247 version's correctness - the database schema is expected to be
248 identical to what it would be if the database were created from
248 identical to what it would be if the database were created from
249 scratch.
249 scratch.
250 """
250 """
251 engine = opts.pop('engine')
251 engine = opts.pop('engine')
252 ControlledSchema.create(engine, repository, version)
252 ControlledSchema.create(engine, repository, version)
253
253
254
254
255 @with_engine
255 @with_engine
256 def drop_version_control(url, repository, **opts):
256 def drop_version_control(url, repository, **opts):
257 """%prog drop_version_control URL REPOSITORY_PATH
257 """%prog drop_version_control URL REPOSITORY_PATH
258
258
259 Removes version control from a database.
259 Removes version control from a database.
260 """
260 """
261 engine = opts.pop('engine')
261 engine = opts.pop('engine')
262 schema = ControlledSchema(engine, repository)
262 schema = ControlledSchema(engine, repository)
263 schema.drop()
263 schema.drop()
264
264
265
265
266 def manage(file, **opts):
266 def manage(file, **opts):
267 """%prog manage FILENAME [VARIABLES...]
267 """%prog manage FILENAME [VARIABLES...]
268
268
269 Creates a script that runs Migrate with a set of default values.
269 Creates a script that runs Migrate with a set of default values.
270
270
271 For example::
271 For example::
272
272
273 %prog manage manage.py --repository=/path/to/repository \
273 %prog manage manage.py --repository=/path/to/repository \
274 --url=sqlite:///project.db
274 --url=sqlite:///project.db
275
275
276 would create the script manage.py. The following two commands
276 would create the script manage.py. The following two commands
277 would then have exactly the same results::
277 would then have exactly the same results::
278
278
279 python manage.py version
279 python manage.py version
280 %prog version --repository=/path/to/repository
280 %prog version --repository=/path/to/repository
281 """
281 """
282 Repository.create_manage_file(file, **opts)
282 Repository.create_manage_file(file, **opts)
283
283
284
284
285 @with_engine
285 @with_engine
286 def compare_model_to_db(url, repository, model, **opts):
286 def compare_model_to_db(url, repository, model, **opts):
287 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
287 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
288
288
289 Compare the current model (assumed to be a module level variable
289 Compare the current model (assumed to be a module level variable
290 of type sqlalchemy.MetaData) against the current database.
290 of type sqlalchemy.MetaData) against the current database.
291
291
292 NOTE: This is EXPERIMENTAL.
292 NOTE: This is EXPERIMENTAL.
293 """ # TODO: get rid of EXPERIMENTAL label
293 """ # TODO: get rid of EXPERIMENTAL label
294 engine = opts.pop('engine')
294 engine = opts.pop('engine')
295 return ControlledSchema.compare_model_to_db(engine, model, repository)
295 return ControlledSchema.compare_model_to_db(engine, model, repository)
296
296
297
297
298 @with_engine
298 @with_engine
299 def create_model(url, repository, **opts):
299 def create_model(url, repository, **opts):
300 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
300 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
301
301
302 Dump the current database as a Python model to stdout.
302 Dump the current database as a Python model to stdout.
303
303
304 NOTE: This is EXPERIMENTAL.
304 NOTE: This is EXPERIMENTAL.
305 """ # TODO: get rid of EXPERIMENTAL label
305 """ # TODO: get rid of EXPERIMENTAL label
306 engine = opts.pop('engine')
306 engine = opts.pop('engine')
307 declarative = opts.get('declarative', False)
307 declarative = opts.get('declarative', False)
308 return ControlledSchema.create_model(engine, repository, declarative)
308 return ControlledSchema.create_model(engine, repository, declarative)
309
309
310
310
311 @catch_known_errors
311 @catch_known_errors
312 @with_engine
312 @with_engine
313 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
313 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
314 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
314 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
315
315
316 Create a script changing the old Python model to the new (current)
316 Create a script changing the old Python model to the new (current)
317 Python model, sending to stdout.
317 Python model, sending to stdout.
318
318
319 NOTE: This is EXPERIMENTAL.
319 NOTE: This is EXPERIMENTAL.
320 """ # TODO: get rid of EXPERIMENTAL label
320 """ # TODO: get rid of EXPERIMENTAL label
321 engine = opts.pop('engine')
321 engine = opts.pop('engine')
322 return PythonScript.make_update_script_for_model(
322 return PythonScript.make_update_script_for_model(
323 engine, oldmodel, model, repository, **opts)
323 engine, oldmodel, model, repository, **opts)
324
324
325
325
326 @with_engine
326 @with_engine
327 def update_db_from_model(url, repository, model, **opts):
327 def update_db_from_model(url, repository, model, **opts):
328 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
328 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
329
329
330 Modify the database to match the structure of the current Python
330 Modify the database to match the structure of the current Python
331 model. This also sets the db_version number to the latest in the
331 model. This also sets the db_version number to the latest in the
332 repository.
332 repository.
333
333
334 NOTE: This is EXPERIMENTAL.
334 NOTE: This is EXPERIMENTAL.
335 """ # TODO: get rid of EXPERIMENTAL label
335 """ # TODO: get rid of EXPERIMENTAL label
336 engine = opts.pop('engine')
336 engine = opts.pop('engine')
337 schema = ControlledSchema(engine, repository)
337 schema = ControlledSchema(engine, repository)
338 schema.update_db_from_model(model)
338 schema.update_db_from_model(model)
339
339
340 @with_engine
340 @with_engine
341 def _migrate(url, repository, version, upgrade, err, **opts):
341 def _migrate(url, repository, version, upgrade, err, **opts):
342 engine = opts.pop('engine')
342 engine = opts.pop('engine')
343 url = str(engine.url)
343 url = str(engine.url)
344 schema = ControlledSchema(engine, repository)
344 schema = ControlledSchema(engine, repository)
345 version = _migrate_version(schema, version, upgrade, err)
345 version = _migrate_version(schema, version, upgrade, err)
346
346
347 changeset = schema.changeset(version)
347 changeset = schema.changeset(version)
348 for ver, change in changeset:
348 for ver, change in changeset:
349 nextver = ver + changeset.step
349 nextver = ver + changeset.step
350 log.info('%s -> %s... ', ver, nextver)
350 log.info('%s -> %s... ', ver, nextver)
351
351
352 if opts.get('preview_sql'):
352 if opts.get('preview_sql'):
353 if isinstance(change, PythonScript):
353 if isinstance(change, PythonScript):
354 log.info(change.preview_sql(url, changeset.step, **opts))
354 log.info(change.preview_sql(url, changeset.step, **opts))
355 elif isinstance(change, SqlScript):
355 elif isinstance(change, SqlScript):
356 log.info(change.source())
356 log.info(change.source())
357
357
358 elif opts.get('preview_py'):
358 elif opts.get('preview_py'):
359 if not isinstance(change, PythonScript):
359 if not isinstance(change, PythonScript):
360 raise exceptions.UsageError("Python source can be only displayed"
360 raise exceptions.UsageError("Python source can be only displayed"
361 " for python migration files")
361 " for python migration files")
362 source_ver = max(ver, nextver)
362 source_ver = max(ver, nextver)
363 module = schema.repository.version(source_ver).script().module
363 module = schema.repository.version(source_ver).script().module
364 funcname = upgrade and "upgrade" or "downgrade"
364 funcname = upgrade and "upgrade" or "downgrade"
365 func = getattr(module, funcname)
365 func = getattr(module, funcname)
366 log.info(inspect.getsource(func))
366 log.info(inspect.getsource(func))
367 else:
367 else:
368 schema.runchange(ver, change, changeset.step)
368 schema.runchange(ver, change, changeset.step)
369 log.info('done')
369 log.info('done')
370
370
371
371
372 def _migrate_version(schema, version, upgrade, err):
372 def _migrate_version(schema, version, upgrade, err):
373 if version is None:
373 if version is None:
374 return version
374 return version
375 # Version is specified: ensure we're upgrading in the right direction
375 # Version is specified: ensure we're upgrading in the right direction
376 # (current version < target version for upgrading; reverse for down)
376 # (current version < target version for upgrading; reverse for down)
377 version = VerNum(version)
377 version = VerNum(version)
378 cur = schema.version
378 cur = schema.version
379 if upgrade is not None:
379 if upgrade is not None:
380 if upgrade:
380 if upgrade:
381 direction = cur <= version
381 direction = cur <= version
382 else:
382 else:
383 direction = cur >= version
383 direction = cur >= version
384 if not direction:
384 if not direction:
385 raise exceptions.KnownError(err % (cur, version))
385 raise exceptions.KnownError(err % (cur, version))
386 return version
386 return version
@@ -1,302 +1,302 b''
1 """
1 """
2 Code to generate a Python model from a database or differences
2 Code to generate a Python model from a database or differences
3 between a model and database.
3 between a model and database.
4
4
5 Some of this is borrowed heavily from the AutoCode project at:
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
6 http://code.google.com/p/sqlautocode/
7 """
7 """
8
8
9 import sys
9 import sys
10 import logging
10 import logging
11
11
12 import sqlalchemy
12 import sqlalchemy
13
13
14 import rhodecode.lib.dbmigrate.migrate
14 import rhodecode.lib.dbmigrate.migrate
15 import rhodecode.lib.dbmigrate.migrate.changeset
15 import rhodecode.lib.dbmigrate.migrate.changeset
16
16
17
17
18 log = logging.getLogger(__name__)
18 log = logging.getLogger(__name__)
19 HEADER = """
19 HEADER = """
20 ## File autogenerated by genmodel.py
20 ## File autogenerated by genmodel.py
21
21
22 from sqlalchemy import *
22 from sqlalchemy import *
23 """
23 """
24
24
25 META_DEFINITION = "meta = MetaData()"
25 META_DEFINITION = "meta = MetaData()"
26
26
27 DECLARATIVE_DEFINITION = """
27 DECLARATIVE_DEFINITION = """
28 from sqlalchemy.ext import declarative
28 from sqlalchemy.ext import declarative
29
29
30 Base = declarative.declarative_base()
30 Base = declarative.declarative_base()
31 """
31 """
32
32
33
33
34 class ModelGenerator(object):
34 class ModelGenerator(object):
35 """Various transformations from an A, B diff.
35 """Various transformations from an A, B diff.
36
36
37 In the implementation, A tends to be called the model and B
37 In the implementation, A tends to be called the model and B
38 the database (although this is not true of all diffs).
38 the database (although this is not true of all diffs).
39 The diff is directionless, but transformations apply the diff
39 The diff is directionless, but transformations apply the diff
40 in a particular direction, described in the method name.
40 in a particular direction, described in the method name.
41 """
41 """
42
42
43 def __init__(self, diff, engine, declarative=False):
43 def __init__(self, diff, engine, declarative=False):
44 self.diff = diff
44 self.diff = diff
45 self.engine = engine
45 self.engine = engine
46 self.declarative = declarative
46 self.declarative = declarative
47
47
48 def column_repr(self, col):
48 def column_repr(self, col):
49 kwarg = []
49 kwarg = []
50 if col.key != col.name:
50 if col.key != col.name:
51 kwarg.append('key')
51 kwarg.append('key')
52 if col.primary_key:
52 if col.primary_key:
53 col.primary_key = True # otherwise it dumps it as 1
53 col.primary_key = True # otherwise it dumps it as 1
54 kwarg.append('primary_key')
54 kwarg.append('primary_key')
55 if not col.nullable:
55 if not col.nullable:
56 kwarg.append('nullable')
56 kwarg.append('nullable')
57 if col.onupdate:
57 if col.onupdate:
58 kwarg.append('onupdate')
58 kwarg.append('onupdate')
59 if col.default:
59 if col.default:
60 if col.primary_key:
60 if col.primary_key:
61 # I found that PostgreSQL automatically creates a
61 # I found that PostgreSQL automatically creates a
62 # default value for the sequence, but let's not show
62 # default value for the sequence, but let's not show
63 # that.
63 # that.
64 pass
64 pass
65 else:
65 else:
66 kwarg.append('default')
66 kwarg.append('default')
67 args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
67 args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
68
68
69 # crs: not sure if this is good idea, but it gets rid of extra
69 # crs: not sure if this is good idea, but it gets rid of extra
70 # u''
70 # u''
71 name = col.name.encode('utf8')
71 name = col.name.encode('utf8')
72
72
73 type_ = col.type
73 type_ = col.type
74 for cls in col.type.__class__.__mro__:
74 for cls in col.type.__class__.__mro__:
75 if cls.__module__ == 'sqlalchemy.types' and \
75 if cls.__module__ == 'sqlalchemy.types' and \
76 not cls.__name__.isupper():
76 not cls.__name__.isupper():
77 if cls is not type_.__class__:
77 if cls is not type_.__class__:
78 type_ = cls()
78 type_ = cls()
79 break
79 break
80
80
81 type_repr = repr(type_)
81 type_repr = repr(type_)
82 if type_repr.endswith('()'):
82 if type_repr.endswith('()'):
83 type_repr = type_repr[:-2]
83 type_repr = type_repr[:-2]
84
84
85 constraints = [repr(cn) for cn in col.constraints]
85 constraints = [repr(cn) for cn in col.constraints]
86
86
87 data = {
87 data = {
88 'name': name,
88 'name': name,
89 'commonStuff': ', '.join([type_repr] + constraints + args),
89 'commonStuff': ', '.join([type_repr] + constraints + args),
90 }
90 }
91
91
92 if self.declarative:
92 if self.declarative:
93 return """%(name)s = Column(%(commonStuff)s)""" % data
93 return """%(name)s = Column(%(commonStuff)s)""" % data
94 else:
94 else:
95 return """Column(%(name)r, %(commonStuff)s)""" % data
95 return """Column(%(name)r, %(commonStuff)s)""" % data
96
96
97 def _getTableDefn(self, table, metaName='meta'):
97 def _getTableDefn(self, table, metaName='meta'):
98 out = []
98 out = []
99 tableName = table.name
99 tableName = table.name
100 if self.declarative:
100 if self.declarative:
101 out.append("class %(table)s(Base):" % {'table': tableName})
101 out.append("class %(table)s(Base):" % {'table': tableName})
102 out.append(" __tablename__ = '%(table)s'\n" %
102 out.append(" __tablename__ = '%(table)s'\n" %
103 {'table': tableName})
103 {'table': tableName})
104 for col in table.columns:
104 for col in table.columns:
105 out.append(" %s" % self.column_repr(col))
105 out.append(" %s" % self.column_repr(col))
106 out.append('\n')
106 out.append('\n')
107 else:
107 else:
108 out.append("%(table)s = Table('%(table)s', %(meta)s," %
108 out.append("%(table)s = Table('%(table)s', %(meta)s," %
109 {'table': tableName, 'meta': metaName})
109 {'table': tableName, 'meta': metaName})
110 for col in table.columns:
110 for col in table.columns:
111 out.append(" %s," % self.column_repr(col))
111 out.append(" %s," % self.column_repr(col))
112 out.append(")\n")
112 out.append(")\n")
113 return out
113 return out
114
114
115 def _get_tables(self,missingA=False,missingB=False,modified=False):
115 def _get_tables(self,missingA=False,missingB=False,modified=False):
116 to_process = []
116 to_process = []
117 for bool_,names,metadata in (
117 for bool_,names,metadata in (
118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
120 (modified,self.diff.tables_different,self.diff.metadataA),
120 (modified,self.diff.tables_different,self.diff.metadataA),
121 ):
121 ):
122 if bool_:
122 if bool_:
123 for name in names:
123 for name in names:
124 yield metadata.tables.get(name)
124 yield metadata.tables.get(name)
125
125
126 def _genModelHeader(self, tables):
126 def _genModelHeader(self, tables):
127 out = []
127 out = []
128 import_index = []
128 import_index = []
129
129
130 out.append(HEADER)
130 out.append(HEADER)
131
131
132 for table in tables:
132 for table in tables:
133 for col in table.columns:
133 for col in table.columns:
134 if "dialects" in col.type.__module__ and \
134 if "dialects" in col.type.__module__ and \
135 col.type.__class__ not in import_index:
135 col.type.__class__ not in import_index:
136 out.append("from " + col.type.__module__ +
136 out.append("from " + col.type.__module__ +
137 " import " + col.type.__class__.__name__)
137 " import " + col.type.__class__.__name__)
138 import_index.append(col.type.__class__)
138 import_index.append(col.type.__class__)
139
139
140 out.append("")
140 out.append("")
141
141
142 if self.declarative:
142 if self.declarative:
143 out.append(DECLARATIVE_DEFINITION)
143 out.append(DECLARATIVE_DEFINITION)
144 else:
144 else:
145 out.append(META_DEFINITION)
145 out.append(META_DEFINITION)
146 out.append("")
146 out.append("")
147
147
148 return out
148 return out
149
149
150 def genBDefinition(self):
150 def genBDefinition(self):
151 """Generates the source code for a definition of B.
151 """Generates the source code for a definition of B.
152
152
153 Assumes a diff where A is empty.
153 Assumes a diff where A is empty.
154
154
155 Was: toPython. Assume database (B) is current and model (A) is empty.
155 Was: toPython. Assume database (B) is current and model (A) is empty.
156 """
156 """
157
157
158 out = []
158 out = []
159 out.extend(self._genModelHeader(self._get_tables(missingA=True)))
159 out.extend(self._genModelHeader(self._get_tables(missingA=True)))
160 for table in self._get_tables(missingA=True):
160 for table in self._get_tables(missingA=True):
161 out.extend(self._getTableDefn(table))
161 out.extend(self._getTableDefn(table))
162 return '\n'.join(out)
162 return '\n'.join(out)
163
163
164 def genB2AMigration(self, indent=' '):
164 def genB2AMigration(self, indent=' '):
165 """Generate a migration from B to A.
165 """Generate a migration from B to A.
166
166
167 Was: toUpgradeDowngradePython
167 Was: toUpgradeDowngradePython
168 Assume model (A) is most current and database (B) is out-of-date.
168 Assume model (A) is most current and database (B) is out-of-date.
169 """
169 """
170
170
171 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
171 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
172 'pre_meta = MetaData()',
172 'pre_meta = MetaData()',
173 'post_meta = MetaData()',
173 'post_meta = MetaData()',
174 ]
174 ]
175 upgradeCommands = ['pre_meta.bind = migrate_engine',
175 upgradeCommands = ['pre_meta.bind = migrate_engine',
176 'post_meta.bind = migrate_engine']
176 'post_meta.bind = migrate_engine']
177 downgradeCommands = list(upgradeCommands)
177 downgradeCommands = list(upgradeCommands)
178
178
179 for tn in self.diff.tables_missing_from_A:
179 for tn in self.diff.tables_missing_from_A:
180 pre_table = self.diff.metadataB.tables[tn]
180 pre_table = self.diff.metadataB.tables[tn]
181 decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
181 decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
182 upgradeCommands.append(
182 upgradeCommands.append(
183 "pre_meta.tables[%(table)r].drop()" % {'table': tn})
183 "pre_meta.tables[%(table)r].drop()" % {'table': tn})
184 downgradeCommands.append(
184 downgradeCommands.append(
185 "pre_meta.tables[%(table)r].create()" % {'table': tn})
185 "pre_meta.tables[%(table)r].create()" % {'table': tn})
186
186
187 for tn in self.diff.tables_missing_from_B:
187 for tn in self.diff.tables_missing_from_B:
188 post_table = self.diff.metadataA.tables[tn]
188 post_table = self.diff.metadataA.tables[tn]
189 decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
189 decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
190 upgradeCommands.append(
190 upgradeCommands.append(
191 "post_meta.tables[%(table)r].create()" % {'table': tn})
191 "post_meta.tables[%(table)r].create()" % {'table': tn})
192 downgradeCommands.append(
192 downgradeCommands.append(
193 "post_meta.tables[%(table)r].drop()" % {'table': tn})
193 "post_meta.tables[%(table)r].drop()" % {'table': tn})
194
194
195 for (tn, td) in self.diff.tables_different.items():
195 for (tn, td) in list(self.diff.tables_different.items()):
196 if td.columns_missing_from_A or td.columns_different:
196 if td.columns_missing_from_A or td.columns_different:
197 pre_table = self.diff.metadataB.tables[tn]
197 pre_table = self.diff.metadataB.tables[tn]
198 decls.extend(self._getTableDefn(
198 decls.extend(self._getTableDefn(
199 pre_table, metaName='pre_meta'))
199 pre_table, metaName='pre_meta'))
200 if td.columns_missing_from_B or td.columns_different:
200 if td.columns_missing_from_B or td.columns_different:
201 post_table = self.diff.metadataA.tables[tn]
201 post_table = self.diff.metadataA.tables[tn]
202 decls.extend(self._getTableDefn(
202 decls.extend(self._getTableDefn(
203 post_table, metaName='post_meta'))
203 post_table, metaName='post_meta'))
204
204
205 for col in td.columns_missing_from_A:
205 for col in td.columns_missing_from_A:
206 upgradeCommands.append(
206 upgradeCommands.append(
207 'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
207 'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
208 downgradeCommands.append(
208 downgradeCommands.append(
209 'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
209 'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
210 for col in td.columns_missing_from_B:
210 for col in td.columns_missing_from_B:
211 upgradeCommands.append(
211 upgradeCommands.append(
212 'post_meta.tables[%r].columns[%r].create()' % (tn, col))
212 'post_meta.tables[%r].columns[%r].create()' % (tn, col))
213 downgradeCommands.append(
213 downgradeCommands.append(
214 'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
214 'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
215 for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
215 for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
216 upgradeCommands.append(
216 upgradeCommands.append(
217 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
217 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
218 tn, modelCol.name, databaseCol.name))
218 tn, modelCol.name, databaseCol.name))
219 downgradeCommands.append(
219 downgradeCommands.append(
220 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
220 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
221 tn, modelCol.name, databaseCol.name))
221 tn, modelCol.name, databaseCol.name))
222
222
223 return (
223 return (
224 '\n'.join(decls),
224 '\n'.join(decls),
225 '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
225 '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
226 '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
226 '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
227
227
228 def _db_can_handle_this_change(self,td):
228 def _db_can_handle_this_change(self,td):
229 """Check if the database can handle going from B to A."""
229 """Check if the database can handle going from B to A."""
230
230
231 if (td.columns_missing_from_B
231 if (td.columns_missing_from_B
232 and not td.columns_missing_from_A
232 and not td.columns_missing_from_A
233 and not td.columns_different):
233 and not td.columns_different):
234 # Even sqlite can handle column additions.
234 # Even sqlite can handle column additions.
235 return True
235 return True
236 else:
236 else:
237 return not self.engine.url.drivername.startswith('sqlite')
237 return not self.engine.url.drivername.startswith('sqlite')
238
238
239 def runB2A(self):
239 def runB2A(self):
240 """Goes from B to A.
240 """Goes from B to A.
241
241
242 Was: applyModel. Apply model (A) to current database (B).
242 Was: applyModel. Apply model (A) to current database (B).
243 """
243 """
244
244
245 meta = sqlalchemy.MetaData(self.engine)
245 meta = sqlalchemy.MetaData(self.engine)
246
246
247 for table in self._get_tables(missingA=True):
247 for table in self._get_tables(missingA=True):
248 table = table.tometadata(meta)
248 table = table.tometadata(meta)
249 table.drop()
249 table.drop()
250 for table in self._get_tables(missingB=True):
250 for table in self._get_tables(missingB=True):
251 table = table.tometadata(meta)
251 table = table.tometadata(meta)
252 table.create()
252 table.create()
253 for modelTable in self._get_tables(modified=True):
253 for modelTable in self._get_tables(modified=True):
254 tableName = modelTable.name
254 tableName = modelTable.name
255 modelTable = modelTable.tometadata(meta)
255 modelTable = modelTable.tometadata(meta)
256 dbTable = self.diff.metadataB.tables[tableName]
256 dbTable = self.diff.metadataB.tables[tableName]
257
257
258 td = self.diff.tables_different[tableName]
258 td = self.diff.tables_different[tableName]
259
259
260 if self._db_can_handle_this_change(td):
260 if self._db_can_handle_this_change(td):
261
261
262 for col in td.columns_missing_from_B:
262 for col in td.columns_missing_from_B:
263 modelTable.columns[col].create()
263 modelTable.columns[col].create()
264 for col in td.columns_missing_from_A:
264 for col in td.columns_missing_from_A:
265 dbTable.columns[col].drop()
265 dbTable.columns[col].drop()
266 # XXX handle column changes here.
266 # XXX handle column changes here.
267 else:
267 else:
268 # Sqlite doesn't support drop column, so you have to
268 # Sqlite doesn't support drop column, so you have to
269 # do more: create temp table, copy data to it, drop
269 # do more: create temp table, copy data to it, drop
270 # old table, create new table, copy data back.
270 # old table, create new table, copy data back.
271 #
271 #
272 # I wonder if this is guaranteed to be unique?
272 # I wonder if this is guaranteed to be unique?
273 tempName = '_temp_%s' % modelTable.name
273 tempName = '_temp_%s' % modelTable.name
274
274
275 def getCopyStatement():
275 def getCopyStatement():
276 preparer = self.engine.dialect.preparer
276 preparer = self.engine.dialect.preparer
277 commonCols = []
277 commonCols = []
278 for modelCol in modelTable.columns:
278 for modelCol in modelTable.columns:
279 if modelCol.name in dbTable.columns:
279 if modelCol.name in dbTable.columns:
280 commonCols.append(modelCol.name)
280 commonCols.append(modelCol.name)
281 commonColsStr = ', '.join(commonCols)
281 commonColsStr = ', '.join(commonCols)
282 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
282 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
283 (tableName, commonColsStr, commonColsStr, tempName)
283 (tableName, commonColsStr, commonColsStr, tempName)
284
284
285 # Move the data in one transaction, so that we don't
285 # Move the data in one transaction, so that we don't
286 # leave the database in a nasty state.
286 # leave the database in a nasty state.
287 connection = self.engine.connect()
287 connection = self.engine.connect()
288 trans = connection.begin()
288 trans = connection.begin()
289 try:
289 try:
290 connection.execute(
290 connection.execute(
291 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
291 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
292 (tempName, modelTable.name))
292 (tempName, modelTable.name))
293 # make sure the drop takes place inside our
293 # make sure the drop takes place inside our
294 # transaction with the bind parameter
294 # transaction with the bind parameter
295 modelTable.drop(bind=connection)
295 modelTable.drop(bind=connection)
296 modelTable.create(bind=connection)
296 modelTable.create(bind=connection)
297 connection.execute(getCopyStatement())
297 connection.execute(getCopyStatement())
298 connection.execute('DROP TABLE %s' % tempName)
298 connection.execute('DROP TABLE %s' % tempName)
299 trans.commit()
299 trans.commit()
300 except:
300 except:
301 trans.rollback()
301 trans.rollback()
302 raise
302 raise
@@ -1,100 +1,100 b''
1 """
1 """
2 Script to migrate repository from sqlalchemy <= 0.4.4 to the new
2 Script to migrate repository from sqlalchemy <= 0.4.4 to the new
3 repository schema. This shouldn't use any other migrate modules, so
3 repository schema. This shouldn't use any other migrate modules, so
4 that it can work in any version.
4 that it can work in any version.
5 """
5 """
6
6
7 import os
7 import os
8 import sys
8 import sys
9 import logging
9 import logging
10
10
11 log = logging.getLogger(__name__)
11 log = logging.getLogger(__name__)
12
12
13
13
14 def usage():
14 def usage():
15 """Gives usage information."""
15 """Gives usage information."""
16 print("""Usage: %(prog)s repository-to-migrate
16 print(("""Usage: %(prog)s repository-to-migrate
17
17
18 Upgrade your repository to the new flat format.
18 Upgrade your repository to the new flat format.
19
19
20 NOTE: You should probably make a backup before running this.
20 NOTE: You should probably make a backup before running this.
21 """ % {'prog': sys.argv[0]})
21 """ % {'prog': sys.argv[0]}))
22
22
23 sys.exit(1)
23 sys.exit(1)
24
24
25
25
26 def delete_file(filepath):
26 def delete_file(filepath):
27 """Deletes a file and prints a message."""
27 """Deletes a file and prints a message."""
28 log.info('Deleting file: %s', filepath)
28 log.info('Deleting file: %s', filepath)
29 os.remove(filepath)
29 os.remove(filepath)
30
30
31
31
32 def move_file(src, tgt):
32 def move_file(src, tgt):
33 """Moves a file and prints a message."""
33 """Moves a file and prints a message."""
34 log.info('Moving file %s to %s', src, tgt)
34 log.info('Moving file %s to %s', src, tgt)
35 if os.path.exists(tgt):
35 if os.path.exists(tgt):
36 raise Exception(
36 raise Exception(
37 'Cannot move file %s because target %s already exists' % \
37 'Cannot move file %s because target %s already exists' % \
38 (src, tgt))
38 (src, tgt))
39 os.rename(src, tgt)
39 os.rename(src, tgt)
40
40
41
41
42 def delete_directory(dirpath):
42 def delete_directory(dirpath):
43 """Delete a directory and print a message."""
43 """Delete a directory and print a message."""
44 log.info('Deleting directory: %s', dirpath)
44 log.info('Deleting directory: %s', dirpath)
45 os.rmdir(dirpath)
45 os.rmdir(dirpath)
46
46
47
47
48 def migrate_repository(repos):
48 def migrate_repository(repos):
49 """Does the actual migration to the new repository format."""
49 """Does the actual migration to the new repository format."""
50 log.info('Migrating repository at: %s to new format', repos)
50 log.info('Migrating repository at: %s to new format', repos)
51 versions = '%s/versions' % repos
51 versions = '%s/versions' % repos
52 dirs = os.listdir(versions)
52 dirs = os.listdir(versions)
53 # Only use int's in list.
53 # Only use int's in list.
54 numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
54 numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
55 numdirs.sort() # Sort list.
55 numdirs.sort() # Sort list.
56 for dirname in numdirs:
56 for dirname in numdirs:
57 origdir = '%s/%s' % (versions, dirname)
57 origdir = '%s/%s' % (versions, dirname)
58 log.info('Working on directory: %s', origdir)
58 log.info('Working on directory: %s', origdir)
59 files = os.listdir(origdir)
59 files = os.listdir(origdir)
60 files.sort()
60 files.sort()
61 for filename in files:
61 for filename in files:
62 # Delete compiled Python files.
62 # Delete compiled Python files.
63 if filename.endswith('.pyc') or filename.endswith('.pyo'):
63 if filename.endswith('.pyc') or filename.endswith('.pyo'):
64 delete_file('%s/%s' % (origdir, filename))
64 delete_file('%s/%s' % (origdir, filename))
65
65
66 # Delete empty __init__.py files.
66 # Delete empty __init__.py files.
67 origfile = '%s/__init__.py' % origdir
67 origfile = '%s/__init__.py' % origdir
68 if os.path.exists(origfile) and len(open(origfile).read()) == 0:
68 if os.path.exists(origfile) and len(open(origfile).read()) == 0:
69 delete_file(origfile)
69 delete_file(origfile)
70
70
71 # Move sql upgrade scripts.
71 # Move sql upgrade scripts.
72 if filename.endswith('.sql'):
72 if filename.endswith('.sql'):
73 version, dbms, operation = filename.split('.', 3)[0:3]
73 version, dbms, operation = filename.split('.', 3)[0:3]
74 origfile = '%s/%s' % (origdir, filename)
74 origfile = '%s/%s' % (origdir, filename)
75 # For instance: 2.postgres.upgrade.sql ->
75 # For instance: 2.postgres.upgrade.sql ->
76 # 002_postgres_upgrade.sql
76 # 002_postgres_upgrade.sql
77 tgtfile = '%s/%03d_%s_%s.sql' % (
77 tgtfile = '%s/%03d_%s_%s.sql' % (
78 versions, int(version), dbms, operation)
78 versions, int(version), dbms, operation)
79 move_file(origfile, tgtfile)
79 move_file(origfile, tgtfile)
80
80
81 # Move Python upgrade script.
81 # Move Python upgrade script.
82 pyfile = '%s.py' % dirname
82 pyfile = '%s.py' % dirname
83 pyfilepath = '%s/%s' % (origdir, pyfile)
83 pyfilepath = '%s/%s' % (origdir, pyfile)
84 if os.path.exists(pyfilepath):
84 if os.path.exists(pyfilepath):
85 tgtfile = '%s/%03d.py' % (versions, int(dirname))
85 tgtfile = '%s/%03d.py' % (versions, int(dirname))
86 move_file(pyfilepath, tgtfile)
86 move_file(pyfilepath, tgtfile)
87
87
88 # Try to remove directory. Will fail if it's not empty.
88 # Try to remove directory. Will fail if it's not empty.
89 delete_directory(origdir)
89 delete_directory(origdir)
90
90
91
91
92 def main():
92 def main():
93 """Main function to be called when using this script."""
93 """Main function to be called when using this script."""
94 if len(sys.argv) != 2:
94 if len(sys.argv) != 2:
95 usage()
95 usage()
96 migrate_repository(sys.argv[1])
96 migrate_repository(sys.argv[1])
97
97
98
98
99 if __name__ == '__main__':
99 if __name__ == '__main__':
100 main()
100 main()
@@ -1,243 +1,243 b''
1 """
1 """
2 SQLAlchemy migrate repository management.
2 SQLAlchemy migrate repository management.
3 """
3 """
4 import os
4 import os
5 import shutil
5 import shutil
6 import string
6 import string
7 import logging
7 import logging
8
8
9 from pkg_resources import resource_filename
9 from pkg_resources import resource_filename
10 from tempita import Template as TempitaTemplate
10 from tempita import Template as TempitaTemplate
11
11
12 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate import exceptions
13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
16
16
17
17
18 log = logging.getLogger(__name__)
18 log = logging.getLogger(__name__)
19
19
20 class Changeset(dict):
20 class Changeset(dict):
21 """A collection of changes to be applied to a database.
21 """A collection of changes to be applied to a database.
22
22
23 Changesets are bound to a repository and manage a set of
23 Changesets are bound to a repository and manage a set of
24 scripts from that repository.
24 scripts from that repository.
25
25
26 Behaves like a dict, for the most part. Keys are ordered based on step value.
26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 """
27 """
28
28
29 def __init__(self, start, *changes, **k):
29 def __init__(self, start, *changes, **k):
30 """
30 """
31 Give a start version; step must be explicitly stated.
31 Give a start version; step must be explicitly stated.
32 """
32 """
33 self.step = k.pop('step', 1)
33 self.step = k.pop('step', 1)
34 self.start = version.VerNum(start)
34 self.start = version.VerNum(start)
35 self.end = self.start
35 self.end = self.start
36 for change in changes:
36 for change in changes:
37 self.add(change)
37 self.add(change)
38
38
39 def __iter__(self):
39 def __iter__(self):
40 return iter(self.items())
40 return iter(list(self.items()))
41
41
42 def keys(self):
42 def keys(self):
43 """
43 """
44 In a series of upgrades x -> y, keys are version x. Sorted.
44 In a series of upgrades x -> y, keys are version x. Sorted.
45 """
45 """
46 ret = super(Changeset, self).keys()
46 ret = list(super(Changeset, self).keys())
47 # Reverse order if downgrading
47 # Reverse order if downgrading
48 ret.sort(reverse=(self.step < 1))
48 ret.sort(reverse=(self.step < 1))
49 return ret
49 return ret
50
50
51 def values(self):
51 def values(self):
52 return [self[k] for k in self.keys()]
52 return [self[k] for k in list(self.keys())]
53
53
54 def items(self):
54 def items(self):
55 return zip(self.keys(), self.values())
55 return list(zip(list(self.keys()), list(self.values())))
56
56
57 def add(self, change):
57 def add(self, change):
58 """Add new change to changeset"""
58 """Add new change to changeset"""
59 key = self.end
59 key = self.end
60 self.end += self.step
60 self.end += self.step
61 self[key] = change
61 self[key] = change
62
62
63 def run(self, *p, **k):
63 def run(self, *p, **k):
64 """Run the changeset scripts"""
64 """Run the changeset scripts"""
65 for version, script in self:
65 for version, script in self:
66 script.run(*p, **k)
66 script.run(*p, **k)
67
67
68
68
69 class Repository(pathed.Pathed):
69 class Repository(pathed.Pathed):
70 """A project's change script repository"""
70 """A project's change script repository"""
71
71
72 _config = 'migrate.cfg'
72 _config = 'migrate.cfg'
73 _versions = 'versions'
73 _versions = 'versions'
74
74
75 def __init__(self, path):
75 def __init__(self, path):
76 log.debug('Loading repository %s...', path)
76 log.debug('Loading repository %s...', path)
77 self.verify(path)
77 self.verify(path)
78 super(Repository, self).__init__(path)
78 super(Repository, self).__init__(path)
79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 self.versions = version.Collection(os.path.join(self.path,
80 self.versions = version.Collection(os.path.join(self.path,
81 self._versions))
81 self._versions))
82 log.debug('Repository %s loaded successfully', path)
82 log.debug('Repository %s loaded successfully', path)
83 log.debug('Config: %r', self.config.to_dict())
83 log.debug('Config: %r', self.config.to_dict())
84
84
85 @classmethod
85 @classmethod
86 def verify(cls, path):
86 def verify(cls, path):
87 """
87 """
88 Ensure the target path is a valid repository.
88 Ensure the target path is a valid repository.
89
89
90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 """
91 """
92 # Ensure the existence of required files
92 # Ensure the existence of required files
93 try:
93 try:
94 cls.require_found(path)
94 cls.require_found(path)
95 cls.require_found(os.path.join(path, cls._config))
95 cls.require_found(os.path.join(path, cls._config))
96 cls.require_found(os.path.join(path, cls._versions))
96 cls.require_found(os.path.join(path, cls._versions))
97 except exceptions.PathNotFoundError as e:
97 except exceptions.PathNotFoundError as e:
98 raise exceptions.InvalidRepositoryError(path)
98 raise exceptions.InvalidRepositoryError(path)
99
99
100 @classmethod
100 @classmethod
101 def prepare_config(cls, tmpl_dir, name, options=None):
101 def prepare_config(cls, tmpl_dir, name, options=None):
102 """
102 """
103 Prepare a project configuration file for a new project.
103 Prepare a project configuration file for a new project.
104
104
105 :param tmpl_dir: Path to Repository template
105 :param tmpl_dir: Path to Repository template
106 :param config_file: Name of the config file in Repository template
106 :param config_file: Name of the config file in Repository template
107 :param name: Repository name
107 :param name: Repository name
108 :type tmpl_dir: string
108 :type tmpl_dir: string
109 :type config_file: string
109 :type config_file: string
110 :type name: string
110 :type name: string
111 :returns: Populated config file
111 :returns: Populated config file
112 """
112 """
113 if options is None:
113 if options is None:
114 options = {}
114 options = {}
115 options.setdefault('version_table', 'migrate_version')
115 options.setdefault('version_table', 'migrate_version')
116 options.setdefault('repository_id', name)
116 options.setdefault('repository_id', name)
117 options.setdefault('required_dbs', [])
117 options.setdefault('required_dbs', [])
118 options.setdefault('use_timestamp_numbering', False)
118 options.setdefault('use_timestamp_numbering', False)
119
119
120 with open(os.path.join(tmpl_dir, cls._config)) as f:
120 with open(os.path.join(tmpl_dir, cls._config)) as f:
121 tmpl = f.read()
121 tmpl = f.read()
122 ret = TempitaTemplate(tmpl).substitute(options)
122 ret = TempitaTemplate(tmpl).substitute(options)
123
123
124 # cleanup
124 # cleanup
125 del options['__template_name__']
125 del options['__template_name__']
126
126
127 return ret
127 return ret
128
128
129 @classmethod
129 @classmethod
130 def create(cls, path, name, **opts):
130 def create(cls, path, name, **opts):
131 """Create a repository at a specified path"""
131 """Create a repository at a specified path"""
132 cls.require_notfound(path)
132 cls.require_notfound(path)
133 theme = opts.pop('templates_theme', None)
133 theme = opts.pop('templates_theme', None)
134 t_path = opts.pop('templates_path', None)
134 t_path = opts.pop('templates_path', None)
135
135
136 # Create repository
136 # Create repository
137 tmpl_dir = Template(t_path).get_repository(theme=theme)
137 tmpl_dir = Template(t_path).get_repository(theme=theme)
138 shutil.copytree(tmpl_dir, path)
138 shutil.copytree(tmpl_dir, path)
139
139
140 # Edit config defaults
140 # Edit config defaults
141 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
141 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
142 with open(os.path.join(path, cls._config), 'w') as fd:
142 with open(os.path.join(path, cls._config), 'w') as fd:
143 fd.write(config_text)
143 fd.write(config_text)
144
144
145 opts['repository_name'] = name
145 opts['repository_name'] = name
146
146
147 # Create a management script
147 # Create a management script
148 manager = os.path.join(path, 'manage.py')
148 manager = os.path.join(path, 'manage.py')
149 Repository.create_manage_file(manager, templates_theme=theme,
149 Repository.create_manage_file(manager, templates_theme=theme,
150 templates_path=t_path, **opts)
150 templates_path=t_path, **opts)
151
151
152 return cls(path)
152 return cls(path)
153
153
154 def create_script(self, description, **k):
154 def create_script(self, description, **k):
155 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
155 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
156
156
157 k['use_timestamp_numbering'] = self.use_timestamp_numbering
157 k['use_timestamp_numbering'] = self.use_timestamp_numbering
158 self.versions.create_new_python_version(description, **k)
158 self.versions.create_new_python_version(description, **k)
159
159
160 def create_script_sql(self, database, description, **k):
160 def create_script_sql(self, database, description, **k):
161 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
161 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
162 k['use_timestamp_numbering'] = self.use_timestamp_numbering
162 k['use_timestamp_numbering'] = self.use_timestamp_numbering
163 self.versions.create_new_sql_version(database, description, **k)
163 self.versions.create_new_sql_version(database, description, **k)
164
164
165 @property
165 @property
166 def latest(self):
166 def latest(self):
167 """API to :attr:`migrate.versioning.version.Collection.latest`"""
167 """API to :attr:`migrate.versioning.version.Collection.latest`"""
168 return self.versions.latest
168 return self.versions.latest
169
169
170 @property
170 @property
171 def version_table(self):
171 def version_table(self):
172 """Returns version_table name specified in config"""
172 """Returns version_table name specified in config"""
173 return self.config.get('db_settings', 'version_table')
173 return self.config.get('db_settings', 'version_table')
174
174
175 @property
175 @property
176 def id(self):
176 def id(self):
177 """Returns repository id specified in config"""
177 """Returns repository id specified in config"""
178 return self.config.get('db_settings', 'repository_id')
178 return self.config.get('db_settings', 'repository_id')
179
179
180 @property
180 @property
181 def use_timestamp_numbering(self):
181 def use_timestamp_numbering(self):
182 """Returns use_timestamp_numbering specified in config"""
182 """Returns use_timestamp_numbering specified in config"""
183 if self.config.has_option('db_settings', 'use_timestamp_numbering'):
183 if self.config.has_option('db_settings', 'use_timestamp_numbering'):
184 return self.config.getboolean('db_settings', 'use_timestamp_numbering')
184 return self.config.getboolean('db_settings', 'use_timestamp_numbering')
185 return False
185 return False
186
186
187 def version(self, *p, **k):
187 def version(self, *p, **k):
188 """API to :attr:`migrate.versioning.version.Collection.version`"""
188 """API to :attr:`migrate.versioning.version.Collection.version`"""
189 return self.versions.version(*p, **k)
189 return self.versions.version(*p, **k)
190
190
191 @classmethod
191 @classmethod
192 def clear(cls):
192 def clear(cls):
193 # TODO: deletes repo
193 # TODO: deletes repo
194 super(Repository, cls).clear()
194 super(Repository, cls).clear()
195 version.Collection.clear()
195 version.Collection.clear()
196
196
197 def changeset(self, database, start, end=None):
197 def changeset(self, database, start, end=None):
198 """Create a changeset to migrate this database from ver. start to end/latest.
198 """Create a changeset to migrate this database from ver. start to end/latest.
199
199
200 :param database: name of database to generate changeset
200 :param database: name of database to generate changeset
201 :param start: version to start at
201 :param start: version to start at
202 :param end: version to end at (latest if None given)
202 :param end: version to end at (latest if None given)
203 :type database: string
203 :type database: string
204 :type start: int
204 :type start: int
205 :type end: int
205 :type end: int
206 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
206 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
207 """
207 """
208 start = version.VerNum(start)
208 start = version.VerNum(start)
209
209
210 if end is None:
210 if end is None:
211 end = self.latest
211 end = self.latest
212 else:
212 else:
213 end = version.VerNum(end)
213 end = version.VerNum(end)
214
214
215 if start <= end:
215 if start <= end:
216 step = 1
216 step = 1
217 range_mod = 1
217 range_mod = 1
218 op = 'upgrade'
218 op = 'upgrade'
219 else:
219 else:
220 step = -1
220 step = -1
221 range_mod = 0
221 range_mod = 0
222 op = 'downgrade'
222 op = 'downgrade'
223
223
224 versions = range(int(start) + range_mod, int(end) + range_mod, step)
224 versions = list(range(int(start) + range_mod, int(end) + range_mod, step))
225 changes = [self.version(v).script(database, op) for v in versions]
225 changes = [self.version(v).script(database, op) for v in versions]
226 ret = Changeset(start, step=step, *changes)
226 ret = Changeset(start, step=step, *changes)
227 return ret
227 return ret
228
228
229 @classmethod
229 @classmethod
230 def create_manage_file(cls, file_, **opts):
230 def create_manage_file(cls, file_, **opts):
231 """Create a project management script (manage.py)
231 """Create a project management script (manage.py)
232
232
233 :param file_: Destination file to be written
233 :param file_: Destination file to be written
234 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
234 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
235 """
235 """
236 mng_file = Template(opts.pop('templates_path', None))\
236 mng_file = Template(opts.pop('templates_path', None))\
237 .get_manage(theme=opts.pop('templates_theme', None))
237 .get_manage(theme=opts.pop('templates_theme', None))
238
238
239 with open(mng_file) as f:
239 with open(mng_file) as f:
240 tmpl = f.read()
240 tmpl = f.read()
241
241
242 with open(file_, 'w') as fd:
242 with open(file_, 'w') as fd:
243 fd.write(TempitaTemplate(tmpl).substitute(opts))
243 fd.write(TempitaTemplate(tmpl).substitute(opts))
@@ -1,221 +1,221 b''
1 """
1 """
2 Database schema version management.
2 Database schema version management.
3 """
3 """
4 import sys
4 import sys
5 import logging
5 import logging
6
6
7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 create_engine)
8 create_engine)
9 from sqlalchemy.sql import and_
9 from sqlalchemy.sql import and_
10 from sqlalchemy import exc as sa_exceptions
10 from sqlalchemy import exc as sa_exceptions
11 from sqlalchemy.sql import bindparam
11 from sqlalchemy.sql import bindparam
12
12
13 from rhodecode.lib.dbmigrate.migrate import exceptions
13 from rhodecode.lib.dbmigrate.migrate import exceptions
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07
15 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
15 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
16 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
16 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
17 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
17 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
18 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
18 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
19
19
20
20
21 log = logging.getLogger(__name__)
21 log = logging.getLogger(__name__)
22
22
23
23
24 class ControlledSchema(object):
24 class ControlledSchema(object):
25 """A database under version control"""
25 """A database under version control"""
26
26
27 def __init__(self, engine, repository):
27 def __init__(self, engine, repository):
28 if isinstance(repository, str):
28 if isinstance(repository, str):
29 repository = Repository(repository)
29 repository = Repository(repository)
30 self.engine = engine
30 self.engine = engine
31 self.repository = repository
31 self.repository = repository
32 self.meta = MetaData(engine)
32 self.meta = MetaData(engine)
33 self.load()
33 self.load()
34
34
35 def __eq__(self, other):
35 def __eq__(self, other):
36 """Compare two schemas by repositories and versions"""
36 """Compare two schemas by repositories and versions"""
37 return (self.repository is other.repository \
37 return (self.repository is other.repository \
38 and self.version == other.version)
38 and self.version == other.version)
39
39
40 def load(self):
40 def load(self):
41 """Load controlled schema version info from DB"""
41 """Load controlled schema version info from DB"""
42 tname = self.repository.version_table
42 tname = self.repository.version_table
43 try:
43 try:
44 if not hasattr(self, 'table') or self.table is None:
44 if not hasattr(self, 'table') or self.table is None:
45 self.table = Table(tname, self.meta, autoload=True)
45 self.table = Table(tname, self.meta, autoload=True)
46
46
47 result = self.engine.execute(self.table.select(
47 result = self.engine.execute(self.table.select(
48 self.table.c.repository_id == str(self.repository.id)))
48 self.table.c.repository_id == str(self.repository.id)))
49
49
50 data = list(result)[0]
50 data = list(result)[0]
51 except:
51 except:
52 cls, exc, tb = sys.exc_info()
52 cls, exc, tb = sys.exc_info()
53 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
53 raise exceptions.DatabaseNotControlledError(exc.__str__()).with_traceback(tb)
54
54
55 self.version = data['version']
55 self.version = data['version']
56 return data
56 return data
57
57
58 def drop(self):
58 def drop(self):
59 """
59 """
60 Remove version control from a database.
60 Remove version control from a database.
61 """
61 """
62 if SQLA_07:
62 if SQLA_07:
63 try:
63 try:
64 self.table.drop()
64 self.table.drop()
65 except sa_exceptions.DatabaseError:
65 except sa_exceptions.DatabaseError:
66 raise exceptions.DatabaseNotControlledError(str(self.table))
66 raise exceptions.DatabaseNotControlledError(str(self.table))
67 else:
67 else:
68 try:
68 try:
69 self.table.drop()
69 self.table.drop()
70 except (sa_exceptions.SQLError):
70 except (sa_exceptions.SQLError):
71 raise exceptions.DatabaseNotControlledError(str(self.table))
71 raise exceptions.DatabaseNotControlledError(str(self.table))
72
72
73 def changeset(self, version=None):
73 def changeset(self, version=None):
74 """API to Changeset creation.
74 """API to Changeset creation.
75
75
76 Uses self.version for start version and engine.name
76 Uses self.version for start version and engine.name
77 to get database name.
77 to get database name.
78 """
78 """
79 database = self.engine.name
79 database = self.engine.name
80 start_ver = self.version
80 start_ver = self.version
81 changeset = self.repository.changeset(database, start_ver, version)
81 changeset = self.repository.changeset(database, start_ver, version)
82 return changeset
82 return changeset
83
83
84 def runchange(self, ver, change, step):
84 def runchange(self, ver, change, step):
85 startver = ver
85 startver = ver
86 endver = ver + step
86 endver = ver + step
87 # Current database version must be correct! Don't run if corrupt!
87 # Current database version must be correct! Don't run if corrupt!
88 if self.version != startver:
88 if self.version != startver:
89 raise exceptions.InvalidVersionError("%s is not %s" % \
89 raise exceptions.InvalidVersionError("%s is not %s" % \
90 (self.version, startver))
90 (self.version, startver))
91 # Run the change
91 # Run the change
92 change.run(self.engine, step)
92 change.run(self.engine, step)
93
93
94 # Update/refresh database version
94 # Update/refresh database version
95 self.update_repository_table(startver, endver)
95 self.update_repository_table(startver, endver)
96 self.load()
96 self.load()
97
97
98 def update_repository_table(self, startver, endver):
98 def update_repository_table(self, startver, endver):
99 """Update version_table with new information"""
99 """Update version_table with new information"""
100 update = self.table.update(and_(self.table.c.version == int(startver),
100 update = self.table.update(and_(self.table.c.version == int(startver),
101 self.table.c.repository_id == str(self.repository.id)))
101 self.table.c.repository_id == str(self.repository.id)))
102 self.engine.execute(update, version=int(endver))
102 self.engine.execute(update, version=int(endver))
103
103
104 def upgrade(self, version=None):
104 def upgrade(self, version=None):
105 """
105 """
106 Upgrade (or downgrade) to a specified version, or latest version.
106 Upgrade (or downgrade) to a specified version, or latest version.
107 """
107 """
108 changeset = self.changeset(version)
108 changeset = self.changeset(version)
109 for ver, change in changeset:
109 for ver, change in changeset:
110 self.runchange(ver, change, changeset.step)
110 self.runchange(ver, change, changeset.step)
111
111
112 def update_db_from_model(self, model):
112 def update_db_from_model(self, model):
113 """
113 """
114 Modify the database to match the structure of the current Python model.
114 Modify the database to match the structure of the current Python model.
115 """
115 """
116 model = load_model(model)
116 model = load_model(model)
117
117
118 diff = schemadiff.getDiffOfModelAgainstDatabase(
118 diff = schemadiff.getDiffOfModelAgainstDatabase(
119 model, self.engine, excludeTables=[self.repository.version_table]
119 model, self.engine, excludeTables=[self.repository.version_table]
120 )
120 )
121 genmodel.ModelGenerator(diff,self.engine).runB2A()
121 genmodel.ModelGenerator(diff,self.engine).runB2A()
122
122
123 self.update_repository_table(self.version, int(self.repository.latest))
123 self.update_repository_table(self.version, int(self.repository.latest))
124
124
125 self.load()
125 self.load()
126
126
127 @classmethod
127 @classmethod
128 def create(cls, engine, repository, version=None):
128 def create(cls, engine, repository, version=None):
129 """
129 """
130 Declare a database to be under a repository's version control.
130 Declare a database to be under a repository's version control.
131
131
132 :raises: :exc:`DatabaseAlreadyControlledError`
132 :raises: :exc:`DatabaseAlreadyControlledError`
133 :returns: :class:`ControlledSchema`
133 :returns: :class:`ControlledSchema`
134 """
134 """
135 # Confirm that the version # is valid: positive, integer,
135 # Confirm that the version # is valid: positive, integer,
136 # exists in repos
136 # exists in repos
137 if isinstance(repository, str):
137 if isinstance(repository, str):
138 repository = Repository(repository)
138 repository = Repository(repository)
139 version = cls._validate_version(repository, version)
139 version = cls._validate_version(repository, version)
140 table = cls._create_table_version(engine, repository, version)
140 table = cls._create_table_version(engine, repository, version)
141 # TODO: history table
141 # TODO: history table
142 # Load repository information and return
142 # Load repository information and return
143 return cls(engine, repository)
143 return cls(engine, repository)
144
144
145 @classmethod
145 @classmethod
146 def _validate_version(cls, repository, version):
146 def _validate_version(cls, repository, version):
147 """
147 """
148 Ensures this is a valid version number for this repository.
148 Ensures this is a valid version number for this repository.
149
149
150 :raises: :exc:`InvalidVersionError` if invalid
150 :raises: :exc:`InvalidVersionError` if invalid
151 :return: valid version number
151 :return: valid version number
152 """
152 """
153 if version is None:
153 if version is None:
154 version = 0
154 version = 0
155 try:
155 try:
156 version = VerNum(version) # raises valueerror
156 version = VerNum(version) # raises valueerror
157 if version < 0 or version > repository.latest:
157 if version < 0 or version > repository.latest:
158 raise ValueError()
158 raise ValueError()
159 except ValueError:
159 except ValueError:
160 raise exceptions.InvalidVersionError(version)
160 raise exceptions.InvalidVersionError(version)
161 return version
161 return version
162
162
163 @classmethod
163 @classmethod
164 def _create_table_version(cls, engine, repository, version):
164 def _create_table_version(cls, engine, repository, version):
165 """
165 """
166 Creates the versioning table in a database.
166 Creates the versioning table in a database.
167
167
168 :raises: :exc:`DatabaseAlreadyControlledError`
168 :raises: :exc:`DatabaseAlreadyControlledError`
169 """
169 """
170 # Create tables
170 # Create tables
171 tname = repository.version_table
171 tname = repository.version_table
172 meta = MetaData(engine)
172 meta = MetaData(engine)
173
173
174 table = Table(
174 table = Table(
175 tname, meta,
175 tname, meta,
176 Column('repository_id', String(250), primary_key=True),
176 Column('repository_id', String(250), primary_key=True),
177 Column('repository_path', Text),
177 Column('repository_path', Text),
178 Column('version', Integer), )
178 Column('version', Integer), )
179
179
180 # there can be multiple repositories/schemas in the same db
180 # there can be multiple repositories/schemas in the same db
181 if not table.exists():
181 if not table.exists():
182 table.create()
182 table.create()
183
183
184 # test for existing repository_id
184 # test for existing repository_id
185 s = table.select(table.c.repository_id == bindparam("repository_id"))
185 s = table.select(table.c.repository_id == bindparam("repository_id"))
186 result = engine.execute(s, repository_id=repository.id)
186 result = engine.execute(s, repository_id=repository.id)
187 if result.fetchone():
187 if result.fetchone():
188 raise exceptions.DatabaseAlreadyControlledError
188 raise exceptions.DatabaseAlreadyControlledError
189
189
190 # Insert data
190 # Insert data
191 engine.execute(table.insert().values(
191 engine.execute(table.insert().values(
192 repository_id=repository.id,
192 repository_id=repository.id,
193 repository_path=repository.path,
193 repository_path=repository.path,
194 version=int(version)))
194 version=int(version)))
195 return table
195 return table
196
196
197 @classmethod
197 @classmethod
198 def compare_model_to_db(cls, engine, model, repository):
198 def compare_model_to_db(cls, engine, model, repository):
199 """
199 """
200 Compare the current model against the current database.
200 Compare the current model against the current database.
201 """
201 """
202 if isinstance(repository, str):
202 if isinstance(repository, str):
203 repository = Repository(repository)
203 repository = Repository(repository)
204 model = load_model(model)
204 model = load_model(model)
205
205
206 diff = schemadiff.getDiffOfModelAgainstDatabase(
206 diff = schemadiff.getDiffOfModelAgainstDatabase(
207 model, engine, excludeTables=[repository.version_table])
207 model, engine, excludeTables=[repository.version_table])
208 return diff
208 return diff
209
209
210 @classmethod
210 @classmethod
211 def create_model(cls, engine, repository, declarative=False):
211 def create_model(cls, engine, repository, declarative=False):
212 """
212 """
213 Dump the current database as a Python model.
213 Dump the current database as a Python model.
214 """
214 """
215 if isinstance(repository, str):
215 if isinstance(repository, str):
216 repository = Repository(repository)
216 repository = Repository(repository)
217
217
218 diff = schemadiff.getDiffOfModelAgainstDatabase(
218 diff = schemadiff.getDiffOfModelAgainstDatabase(
219 MetaData(), engine, excludeTables=[repository.version_table]
219 MetaData(), engine, excludeTables=[repository.version_table]
220 )
220 )
221 return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()
221 return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()
@@ -1,299 +1,299 b''
1 """
1 """
2 Schema differencing support.
2 Schema differencing support.
3 """
3 """
4
4
5 import logging
5 import logging
6 import sqlalchemy
6 import sqlalchemy
7
7
8 from sqlalchemy.types import Float
8 from sqlalchemy.types import Float
9
9
10 log = logging.getLogger(__name__)
10 log = logging.getLogger(__name__)
11
11
12
12
13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
14 """
14 """
15 Return differences of model against database.
15 Return differences of model against database.
16
16
17 :return: object which will evaluate to :keyword:`True` if there \
17 :return: object which will evaluate to :keyword:`True` if there \
18 are differences else :keyword:`False`.
18 are differences else :keyword:`False`.
19 """
19 """
20 db_metadata = sqlalchemy.MetaData(engine)
20 db_metadata = sqlalchemy.MetaData(engine)
21 db_metadata.reflect()
21 db_metadata.reflect()
22
22
23 # sqlite will include a dynamically generated 'sqlite_sequence' table if
23 # sqlite will include a dynamically generated 'sqlite_sequence' table if
24 # there are autoincrement sequences in the database; this should not be
24 # there are autoincrement sequences in the database; this should not be
25 # compared.
25 # compared.
26 if engine.dialect.name == 'sqlite':
26 if engine.dialect.name == 'sqlite':
27 if 'sqlite_sequence' in db_metadata.tables:
27 if 'sqlite_sequence' in db_metadata.tables:
28 db_metadata.remove(db_metadata.tables['sqlite_sequence'])
28 db_metadata.remove(db_metadata.tables['sqlite_sequence'])
29
29
30 return SchemaDiff(metadata, db_metadata,
30 return SchemaDiff(metadata, db_metadata,
31 labelA='model',
31 labelA='model',
32 labelB='database',
32 labelB='database',
33 excludeTables=excludeTables)
33 excludeTables=excludeTables)
34
34
35
35
36 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
36 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
37 """
37 """
38 Return differences of model against another model.
38 Return differences of model against another model.
39
39
40 :return: object which will evaluate to :keyword:`True` if there \
40 :return: object which will evaluate to :keyword:`True` if there \
41 are differences else :keyword:`False`.
41 are differences else :keyword:`False`.
42 """
42 """
43 return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables)
43 return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables)
44
44
45
45
46 class ColDiff(object):
46 class ColDiff(object):
47 """
47 """
48 Container for differences in one :class:`~sqlalchemy.schema.Column`
48 Container for differences in one :class:`~sqlalchemy.schema.Column`
49 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
49 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
50 and ``B``.
50 and ``B``.
51
51
52 .. attribute:: col_A
52 .. attribute:: col_A
53
53
54 The :class:`~sqlalchemy.schema.Column` object for A.
54 The :class:`~sqlalchemy.schema.Column` object for A.
55
55
56 .. attribute:: col_B
56 .. attribute:: col_B
57
57
58 The :class:`~sqlalchemy.schema.Column` object for B.
58 The :class:`~sqlalchemy.schema.Column` object for B.
59
59
60 .. attribute:: type_A
60 .. attribute:: type_A
61
61
62 The most generic type of the :class:`~sqlalchemy.schema.Column`
62 The most generic type of the :class:`~sqlalchemy.schema.Column`
63 object in A.
63 object in A.
64
64
65 .. attribute:: type_B
65 .. attribute:: type_B
66
66
67 The most generic type of the :class:`~sqlalchemy.schema.Column`
67 The most generic type of the :class:`~sqlalchemy.schema.Column`
68 object in A.
68 object in A.
69
69
70 """
70 """
71
71
72 diff = False
72 diff = False
73
73
74 def __init__(self,col_A,col_B):
74 def __init__(self,col_A,col_B):
75 self.col_A = col_A
75 self.col_A = col_A
76 self.col_B = col_B
76 self.col_B = col_B
77
77
78 self.type_A = col_A.type
78 self.type_A = col_A.type
79 self.type_B = col_B.type
79 self.type_B = col_B.type
80
80
81 self.affinity_A = self.type_A._type_affinity
81 self.affinity_A = self.type_A._type_affinity
82 self.affinity_B = self.type_B._type_affinity
82 self.affinity_B = self.type_B._type_affinity
83
83
84 if self.affinity_A is not self.affinity_B:
84 if self.affinity_A is not self.affinity_B:
85 self.diff = True
85 self.diff = True
86 return
86 return
87
87
88 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
88 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
89 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
89 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
90 self.diff=True
90 self.diff=True
91 return
91 return
92
92
93 for attr in ('precision','scale','length'):
93 for attr in ('precision','scale','length'):
94 A = getattr(self.type_A,attr,None)
94 A = getattr(self.type_A,attr,None)
95 B = getattr(self.type_B,attr,None)
95 B = getattr(self.type_B,attr,None)
96 if not (A is None or B is None) and A!=B:
96 if not (A is None or B is None) and A!=B:
97 self.diff=True
97 self.diff=True
98 return
98 return
99
99
100 def __nonzero__(self):
100 def __bool__(self):
101 return self.diff
101 return self.diff
102
102
103 __bool__ = __nonzero__
103 __bool__ = __nonzero__
104
104
105
105
106 class TableDiff(object):
106 class TableDiff(object):
107 """
107 """
108 Container for differences in one :class:`~sqlalchemy.schema.Table`
108 Container for differences in one :class:`~sqlalchemy.schema.Table`
109 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
109 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
110 and ``B``.
110 and ``B``.
111
111
112 .. attribute:: columns_missing_from_A
112 .. attribute:: columns_missing_from_A
113
113
114 A sequence of column names that were found in B but weren't in
114 A sequence of column names that were found in B but weren't in
115 A.
115 A.
116
116
117 .. attribute:: columns_missing_from_B
117 .. attribute:: columns_missing_from_B
118
118
119 A sequence of column names that were found in A but weren't in
119 A sequence of column names that were found in A but weren't in
120 B.
120 B.
121
121
122 .. attribute:: columns_different
122 .. attribute:: columns_different
123
123
124 A dictionary containing information about columns that were
124 A dictionary containing information about columns that were
125 found to be different.
125 found to be different.
126 It maps column names to a :class:`ColDiff` objects describing the
126 It maps column names to a :class:`ColDiff` objects describing the
127 differences found.
127 differences found.
128 """
128 """
129 __slots__ = (
129 __slots__ = (
130 'columns_missing_from_A',
130 'columns_missing_from_A',
131 'columns_missing_from_B',
131 'columns_missing_from_B',
132 'columns_different',
132 'columns_different',
133 )
133 )
134
134
135 def __nonzero__(self):
135 def __bool__(self):
136 return bool(
136 return bool(
137 self.columns_missing_from_A or
137 self.columns_missing_from_A or
138 self.columns_missing_from_B or
138 self.columns_missing_from_B or
139 self.columns_different
139 self.columns_different
140 )
140 )
141
141
142 __bool__ = __nonzero__
142 __bool__ = __nonzero__
143
143
144 class SchemaDiff(object):
144 class SchemaDiff(object):
145 """
145 """
146 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
146 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
147 objects.
147 objects.
148
148
149 The string representation of a :class:`SchemaDiff` will summarise
149 The string representation of a :class:`SchemaDiff` will summarise
150 the changes found between the two
150 the changes found between the two
151 :class:`~sqlalchemy.schema.MetaData` objects.
151 :class:`~sqlalchemy.schema.MetaData` objects.
152
152
153 The length of a :class:`SchemaDiff` will give the number of
153 The length of a :class:`SchemaDiff` will give the number of
154 changes found, enabling it to be used much like a boolean in
154 changes found, enabling it to be used much like a boolean in
155 expressions.
155 expressions.
156
156
157 :param metadataA:
157 :param metadataA:
158 First :class:`~sqlalchemy.schema.MetaData` to compare.
158 First :class:`~sqlalchemy.schema.MetaData` to compare.
159
159
160 :param metadataB:
160 :param metadataB:
161 Second :class:`~sqlalchemy.schema.MetaData` to compare.
161 Second :class:`~sqlalchemy.schema.MetaData` to compare.
162
162
163 :param labelA:
163 :param labelA:
164 The label to use in messages about the first
164 The label to use in messages about the first
165 :class:`~sqlalchemy.schema.MetaData`.
165 :class:`~sqlalchemy.schema.MetaData`.
166
166
167 :param labelB:
167 :param labelB:
168 The label to use in messages about the second
168 The label to use in messages about the second
169 :class:`~sqlalchemy.schema.MetaData`.
169 :class:`~sqlalchemy.schema.MetaData`.
170
170
171 :param excludeTables:
171 :param excludeTables:
172 A sequence of table names to exclude.
172 A sequence of table names to exclude.
173
173
174 .. attribute:: tables_missing_from_A
174 .. attribute:: tables_missing_from_A
175
175
176 A sequence of table names that were found in B but weren't in
176 A sequence of table names that were found in B but weren't in
177 A.
177 A.
178
178
179 .. attribute:: tables_missing_from_B
179 .. attribute:: tables_missing_from_B
180
180
181 A sequence of table names that were found in A but weren't in
181 A sequence of table names that were found in A but weren't in
182 B.
182 B.
183
183
184 .. attribute:: tables_different
184 .. attribute:: tables_different
185
185
186 A dictionary containing information about tables that were found
186 A dictionary containing information about tables that were found
187 to be different.
187 to be different.
188 It maps table names to a :class:`TableDiff` objects describing the
188 It maps table names to a :class:`TableDiff` objects describing the
189 differences found.
189 differences found.
190 """
190 """
191
191
192 def __init__(self,
192 def __init__(self,
193 metadataA, metadataB,
193 metadataA, metadataB,
194 labelA='metadataA',
194 labelA='metadataA',
195 labelB='metadataB',
195 labelB='metadataB',
196 excludeTables=None):
196 excludeTables=None):
197
197
198 self.metadataA, self.metadataB = metadataA, metadataB
198 self.metadataA, self.metadataB = metadataA, metadataB
199 self.labelA, self.labelB = labelA, labelB
199 self.labelA, self.labelB = labelA, labelB
200 self.label_width = max(len(labelA),len(labelB))
200 self.label_width = max(len(labelA),len(labelB))
201 excludeTables = set(excludeTables or [])
201 excludeTables = set(excludeTables or [])
202
202
203 A_table_names = set(metadataA.tables.keys())
203 A_table_names = set(metadataA.tables.keys())
204 B_table_names = set(metadataB.tables.keys())
204 B_table_names = set(metadataB.tables.keys())
205
205
206 self.tables_missing_from_A = sorted(
206 self.tables_missing_from_A = sorted(
207 B_table_names - A_table_names - excludeTables
207 B_table_names - A_table_names - excludeTables
208 )
208 )
209 self.tables_missing_from_B = sorted(
209 self.tables_missing_from_B = sorted(
210 A_table_names - B_table_names - excludeTables
210 A_table_names - B_table_names - excludeTables
211 )
211 )
212
212
213 self.tables_different = {}
213 self.tables_different = {}
214 for table_name in A_table_names.intersection(B_table_names):
214 for table_name in A_table_names.intersection(B_table_names):
215
215
216 td = TableDiff()
216 td = TableDiff()
217
217
218 A_table = metadataA.tables[table_name]
218 A_table = metadataA.tables[table_name]
219 B_table = metadataB.tables[table_name]
219 B_table = metadataB.tables[table_name]
220
220
221 A_column_names = set(A_table.columns.keys())
221 A_column_names = set(A_table.columns.keys())
222 B_column_names = set(B_table.columns.keys())
222 B_column_names = set(B_table.columns.keys())
223
223
224 td.columns_missing_from_A = sorted(
224 td.columns_missing_from_A = sorted(
225 B_column_names - A_column_names
225 B_column_names - A_column_names
226 )
226 )
227
227
228 td.columns_missing_from_B = sorted(
228 td.columns_missing_from_B = sorted(
229 A_column_names - B_column_names
229 A_column_names - B_column_names
230 )
230 )
231
231
232 td.columns_different = {}
232 td.columns_different = {}
233
233
234 for col_name in A_column_names.intersection(B_column_names):
234 for col_name in A_column_names.intersection(B_column_names):
235
235
236 cd = ColDiff(
236 cd = ColDiff(
237 A_table.columns.get(col_name),
237 A_table.columns.get(col_name),
238 B_table.columns.get(col_name)
238 B_table.columns.get(col_name)
239 )
239 )
240
240
241 if cd:
241 if cd:
242 td.columns_different[col_name]=cd
242 td.columns_different[col_name]=cd
243
243
244 # XXX - index and constraint differences should
244 # XXX - index and constraint differences should
245 # be checked for here
245 # be checked for here
246
246
247 if td:
247 if td:
248 self.tables_different[table_name]=td
248 self.tables_different[table_name]=td
249
249
250 def __str__(self):
250 def __str__(self):
251 """ Summarize differences. """
251 """ Summarize differences. """
252 out = []
252 out = []
253 column_template =' %%%is: %%r' % self.label_width
253 column_template =' %%%is: %%r' % self.label_width
254
254
255 for names,label in (
255 for names,label in (
256 (self.tables_missing_from_A,self.labelA),
256 (self.tables_missing_from_A,self.labelA),
257 (self.tables_missing_from_B,self.labelB),
257 (self.tables_missing_from_B,self.labelB),
258 ):
258 ):
259 if names:
259 if names:
260 out.append(
260 out.append(
261 ' tables missing from %s: %s' % (
261 ' tables missing from %s: %s' % (
262 label,', '.join(sorted(names))
262 label,', '.join(sorted(names))
263 )
263 )
264 )
264 )
265
265
266 for name,td in sorted(self.tables_different.items()):
266 for name,td in sorted(self.tables_different.items()):
267 out.append(
267 out.append(
268 ' table with differences: %s' % name
268 ' table with differences: %s' % name
269 )
269 )
270 for names,label in (
270 for names,label in (
271 (td.columns_missing_from_A,self.labelA),
271 (td.columns_missing_from_A,self.labelA),
272 (td.columns_missing_from_B,self.labelB),
272 (td.columns_missing_from_B,self.labelB),
273 ):
273 ):
274 if names:
274 if names:
275 out.append(
275 out.append(
276 ' %s missing these columns: %s' % (
276 ' %s missing these columns: %s' % (
277 label,', '.join(sorted(names))
277 label,', '.join(sorted(names))
278 )
278 )
279 )
279 )
280 for name,cd in td.columns_different.items():
280 for name,cd in list(td.columns_different.items()):
281 out.append(' column with differences: %s' % name)
281 out.append(' column with differences: %s' % name)
282 out.append(column_template % (self.labelA,cd.col_A))
282 out.append(column_template % (self.labelA,cd.col_A))
283 out.append(column_template % (self.labelB,cd.col_B))
283 out.append(column_template % (self.labelB,cd.col_B))
284
284
285 if out:
285 if out:
286 out.insert(0, 'Schema diffs:')
286 out.insert(0, 'Schema diffs:')
287 return '\n'.join(out)
287 return '\n'.join(out)
288 else:
288 else:
289 return 'No schema diffs'
289 return 'No schema diffs'
290
290
291 def __len__(self):
291 def __len__(self):
292 """
292 """
293 Used in bool evaluation, return of 0 means no diffs.
293 Used in bool evaluation, return of 0 means no diffs.
294 """
294 """
295 return (
295 return (
296 len(self.tables_missing_from_A) +
296 len(self.tables_missing_from_A) +
297 len(self.tables_missing_from_B) +
297 len(self.tables_missing_from_B) +
298 len(self.tables_different)
298 len(self.tables_different)
299 )
299 )
@@ -1,215 +1,215 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 """The migrate command-line tool."""
4 """The migrate command-line tool."""
5
5
6 import sys
6 import sys
7 import inspect
7 import inspect
8 import logging
8 import logging
9 from optparse import OptionParser, BadOptionError
9 from optparse import OptionParser, BadOptionError
10
10
11 from rhodecode.lib.dbmigrate.migrate import exceptions
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
15
15
16
16
17 alias = {
17 alias = {
18 's': api.script,
18 's': api.script,
19 'vc': api.version_control,
19 'vc': api.version_control,
20 'dbv': api.db_version,
20 'dbv': api.db_version,
21 'v': api.version,
21 'v': api.version,
22 }
22 }
23
23
24 def alias_setup():
24 def alias_setup():
25 global alias
25 global alias
26 for key, val in alias.items():
26 for key, val in list(alias.items()):
27 setattr(api, key, val)
27 setattr(api, key, val)
28 alias_setup()
28 alias_setup()
29
29
30
30
31 class PassiveOptionParser(OptionParser):
31 class PassiveOptionParser(OptionParser):
32
32
33 def _process_args(self, largs, rargs, values):
33 def _process_args(self, largs, rargs, values):
34 """little hack to support all --some_option=value parameters"""
34 """little hack to support all --some_option=value parameters"""
35
35
36 while rargs:
36 while rargs:
37 arg = rargs[0]
37 arg = rargs[0]
38 if arg == "--":
38 if arg == "--":
39 del rargs[0]
39 del rargs[0]
40 return
40 return
41 elif arg[0:2] == "--":
41 elif arg[0:2] == "--":
42 # if parser does not know about the option
42 # if parser does not know about the option
43 # pass it along (make it anonymous)
43 # pass it along (make it anonymous)
44 try:
44 try:
45 opt = arg.split('=', 1)[0]
45 opt = arg.split('=', 1)[0]
46 self._match_long_opt(opt)
46 self._match_long_opt(opt)
47 except BadOptionError:
47 except BadOptionError:
48 largs.append(arg)
48 largs.append(arg)
49 del rargs[0]
49 del rargs[0]
50 else:
50 else:
51 self._process_long_opt(rargs, values)
51 self._process_long_opt(rargs, values)
52 elif arg[:1] == "-" and len(arg) > 1:
52 elif arg[:1] == "-" and len(arg) > 1:
53 self._process_short_opts(rargs, values)
53 self._process_short_opts(rargs, values)
54 elif self.allow_interspersed_args:
54 elif self.allow_interspersed_args:
55 largs.append(arg)
55 largs.append(arg)
56 del rargs[0]
56 del rargs[0]
57
57
58 def main(argv=None, **kwargs):
58 def main(argv=None, **kwargs):
59 """Shell interface to :mod:`migrate.versioning.api`.
59 """Shell interface to :mod:`migrate.versioning.api`.
60
60
61 kwargs are default options that can be overriden with passing
61 kwargs are default options that can be overriden with passing
62 --some_option as command line option
62 --some_option as command line option
63
63
64 :param disable_logging: Let migrate configure logging
64 :param disable_logging: Let migrate configure logging
65 :type disable_logging: bool
65 :type disable_logging: bool
66 """
66 """
67 if argv is not None:
67 if argv is not None:
68 argv = argv
68 argv = argv
69 else:
69 else:
70 argv = list(sys.argv[1:])
70 argv = list(sys.argv[1:])
71 commands = list(api.__all__)
71 commands = list(api.__all__)
72 commands.sort()
72 commands.sort()
73
73
74 usage = """%%prog COMMAND ...
74 usage = """%%prog COMMAND ...
75
75
76 Available commands:
76 Available commands:
77 %s
77 %s
78
78
79 Enter "%%prog help COMMAND" for information on a particular command.
79 Enter "%%prog help COMMAND" for information on a particular command.
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
81
81
82 parser = PassiveOptionParser(usage=usage)
82 parser = PassiveOptionParser(usage=usage)
83 parser.add_option("-d", "--debug",
83 parser.add_option("-d", "--debug",
84 action="store_true",
84 action="store_true",
85 dest="debug",
85 dest="debug",
86 default=False,
86 default=False,
87 help="Shortcut to turn on DEBUG mode for logging")
87 help="Shortcut to turn on DEBUG mode for logging")
88 parser.add_option("-q", "--disable_logging",
88 parser.add_option("-q", "--disable_logging",
89 action="store_true",
89 action="store_true",
90 dest="disable_logging",
90 dest="disable_logging",
91 default=False,
91 default=False,
92 help="Use this option to disable logging configuration")
92 help="Use this option to disable logging configuration")
93 help_commands = ['help', '-h', '--help']
93 help_commands = ['help', '-h', '--help']
94 HELP = False
94 HELP = False
95
95
96 try:
96 try:
97 command = argv.pop(0)
97 command = argv.pop(0)
98 if command in help_commands:
98 if command in help_commands:
99 HELP = True
99 HELP = True
100 command = argv.pop(0)
100 command = argv.pop(0)
101 except IndexError:
101 except IndexError:
102 parser.print_help()
102 parser.print_help()
103 return
103 return
104
104
105 command_func = getattr(api, command, None)
105 command_func = getattr(api, command, None)
106 if command_func is None or command.startswith('_'):
106 if command_func is None or command.startswith('_'):
107 parser.error("Invalid command %s" % command)
107 parser.error("Invalid command %s" % command)
108
108
109 parser.set_usage(inspect.getdoc(command_func))
109 parser.set_usage(inspect.getdoc(command_func))
110 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
110 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
111 for arg in f_args:
111 for arg in f_args:
112 parser.add_option(
112 parser.add_option(
113 "--%s" % arg,
113 "--%s" % arg,
114 dest=arg,
114 dest=arg,
115 action='store',
115 action='store',
116 type="string")
116 type="string")
117
117
118 # display help of the current command
118 # display help of the current command
119 if HELP:
119 if HELP:
120 parser.print_help()
120 parser.print_help()
121 return
121 return
122
122
123 options, args = parser.parse_args(argv)
123 options, args = parser.parse_args(argv)
124
124
125 # override kwargs with anonymous parameters
125 # override kwargs with anonymous parameters
126 override_kwargs = {}
126 override_kwargs = {}
127 for arg in list(args):
127 for arg in list(args):
128 if arg.startswith('--'):
128 if arg.startswith('--'):
129 args.remove(arg)
129 args.remove(arg)
130 if '=' in arg:
130 if '=' in arg:
131 opt, value = arg[2:].split('=', 1)
131 opt, value = arg[2:].split('=', 1)
132 else:
132 else:
133 opt = arg[2:]
133 opt = arg[2:]
134 value = True
134 value = True
135 override_kwargs[opt] = value
135 override_kwargs[opt] = value
136
136
137 # override kwargs with options if user is overwriting
137 # override kwargs with options if user is overwriting
138 for key, value in options.__dict__.items():
138 for key, value in list(options.__dict__.items()):
139 if value is not None:
139 if value is not None:
140 override_kwargs[key] = value
140 override_kwargs[key] = value
141
141
142 # arguments that function accepts without passed kwargs
142 # arguments that function accepts without passed kwargs
143 f_required = list(f_args)
143 f_required = list(f_args)
144 candidates = dict(kwargs)
144 candidates = dict(kwargs)
145 candidates.update(override_kwargs)
145 candidates.update(override_kwargs)
146 for key, value in candidates.items():
146 for key, value in list(candidates.items()):
147 if key in f_args:
147 if key in f_args:
148 f_required.remove(key)
148 f_required.remove(key)
149
149
150 # map function arguments to parsed arguments
150 # map function arguments to parsed arguments
151 for arg in args:
151 for arg in args:
152 try:
152 try:
153 kw = f_required.pop(0)
153 kw = f_required.pop(0)
154 except IndexError:
154 except IndexError:
155 parser.error("Too many arguments for command %s: %s" % (command,
155 parser.error("Too many arguments for command %s: %s" % (command,
156 arg))
156 arg))
157 kwargs[kw] = arg
157 kwargs[kw] = arg
158
158
159 # apply overrides
159 # apply overrides
160 kwargs.update(override_kwargs)
160 kwargs.update(override_kwargs)
161
161
162 # configure options
162 # configure options
163 for key, value in options.__dict__.items():
163 for key, value in list(options.__dict__.items()):
164 kwargs.setdefault(key, value)
164 kwargs.setdefault(key, value)
165
165
166 # configure logging
166 # configure logging
167 if not asbool(kwargs.pop('disable_logging', False)):
167 if not asbool(kwargs.pop('disable_logging', False)):
168 # filter to log =< INFO into stdout and rest to stderr
168 # filter to log =< INFO into stdout and rest to stderr
169 class SingleLevelFilter(logging.Filter):
169 class SingleLevelFilter(logging.Filter):
170 def __init__(self, min=None, max=None):
170 def __init__(self, min=None, max=None):
171 self.min = min or 0
171 self.min = min or 0
172 self.max = max or 100
172 self.max = max or 100
173
173
174 def filter(self, record):
174 def filter(self, record):
175 return self.min <= record.levelno <= self.max
175 return self.min <= record.levelno <= self.max
176
176
177 logger = logging.getLogger()
177 logger = logging.getLogger()
178 h1 = logging.StreamHandler(sys.stdout)
178 h1 = logging.StreamHandler(sys.stdout)
179 f1 = SingleLevelFilter(max=logging.INFO)
179 f1 = SingleLevelFilter(max=logging.INFO)
180 h1.addFilter(f1)
180 h1.addFilter(f1)
181 h2 = logging.StreamHandler(sys.stderr)
181 h2 = logging.StreamHandler(sys.stderr)
182 f2 = SingleLevelFilter(min=logging.WARN)
182 f2 = SingleLevelFilter(min=logging.WARN)
183 h2.addFilter(f2)
183 h2.addFilter(f2)
184 logger.addHandler(h1)
184 logger.addHandler(h1)
185 logger.addHandler(h2)
185 logger.addHandler(h2)
186
186
187 if options.debug:
187 if options.debug:
188 logger.setLevel(logging.DEBUG)
188 logger.setLevel(logging.DEBUG)
189 else:
189 else:
190 logger.setLevel(logging.INFO)
190 logger.setLevel(logging.INFO)
191
191
192 log = logging.getLogger(__name__)
192 log = logging.getLogger(__name__)
193
193
194 # check if all args are given
194 # check if all args are given
195 try:
195 try:
196 num_defaults = len(f_defaults)
196 num_defaults = len(f_defaults)
197 except TypeError:
197 except TypeError:
198 num_defaults = 0
198 num_defaults = 0
199 f_args_default = f_args[len(f_args) - num_defaults:]
199 f_args_default = f_args[len(f_args) - num_defaults:]
200 required = list(set(f_required) - set(f_args_default))
200 required = list(set(f_required) - set(f_args_default))
201 required.sort()
201 required.sort()
202 if required:
202 if required:
203 parser.error("Not enough arguments for command %s: %s not specified" \
203 parser.error("Not enough arguments for command %s: %s not specified" \
204 % (command, ', '.join(required)))
204 % (command, ', '.join(required)))
205
205
206 # handle command
206 # handle command
207 try:
207 try:
208 ret = command_func(**kwargs)
208 ret = command_func(**kwargs)
209 if ret is not None:
209 if ret is not None:
210 log.info(ret)
210 log.info(ret)
211 except (exceptions.UsageError, exceptions.KnownError) as e:
211 except (exceptions.UsageError, exceptions.KnownError) as e:
212 parser.error(e.args[0])
212 parser.error(e.args[0])
213
213
214 if __name__ == "__main__":
214 if __name__ == "__main__":
215 main()
215 main()
@@ -1,180 +1,180 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 """.. currentmodule:: migrate.versioning.util"""
3 """.. currentmodule:: migrate.versioning.util"""
4
4
5 import warnings
5 import warnings
6 import logging
6 import logging
7 from decorator import decorator
7 from decorator import decorator
8 from pkg_resources import EntryPoint
8 from pkg_resources import EntryPoint
9
9
10 from sqlalchemy import create_engine
10 from sqlalchemy import create_engine
11 from sqlalchemy.engine import Engine
11 from sqlalchemy.engine import Engine
12 from sqlalchemy.pool import StaticPool
12 from sqlalchemy.pool import StaticPool
13
13
14 from rhodecode.lib.dbmigrate.migrate import exceptions
14 from rhodecode.lib.dbmigrate.migrate import exceptions
15 from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance
15 from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance
16 from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path
16 from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path
17
17
18
18
19 log = logging.getLogger(__name__)
19 log = logging.getLogger(__name__)
20
20
21
21
22 def load_model(dotted_name):
22 def load_model(dotted_name):
23 """Import module and use module-level variable".
23 """Import module and use module-level variable".
24
24
25 :param dotted_name: path to model in form of string: ``some.python.module:Class``
25 :param dotted_name: path to model in form of string: ``some.python.module:Class``
26
26
27 .. versionchanged:: 0.5.4
27 .. versionchanged:: 0.5.4
28
28
29 """
29 """
30 if isinstance(dotted_name, str):
30 if isinstance(dotted_name, str):
31 if ':' not in dotted_name:
31 if ':' not in dotted_name:
32 # backwards compatibility
32 # backwards compatibility
33 warnings.warn('model should be in form of module.model:User '
33 warnings.warn('model should be in form of module.model:User '
34 'and not module.model.User', exceptions.MigrateDeprecationWarning)
34 'and not module.model.User', exceptions.MigrateDeprecationWarning)
35 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
35 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
36 return EntryPoint.parse('x=%s' % dotted_name).load(False)
36 return EntryPoint.parse('x=%s' % dotted_name).load(False)
37 else:
37 else:
38 # Assume it's already loaded.
38 # Assume it's already loaded.
39 return dotted_name
39 return dotted_name
40
40
41 def asbool(obj):
41 def asbool(obj):
42 """Do everything to use object as bool"""
42 """Do everything to use object as bool"""
43 if isinstance(obj, str):
43 if isinstance(obj, str):
44 obj = obj.strip().lower()
44 obj = obj.strip().lower()
45 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
45 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
46 return True
46 return True
47 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
47 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
48 return False
48 return False
49 else:
49 else:
50 raise ValueError("String is not true/false: %r" % obj)
50 raise ValueError("String is not true/false: %r" % obj)
51 if obj in (True, False):
51 if obj in (True, False):
52 return bool(obj)
52 return bool(obj)
53 else:
53 else:
54 raise ValueError("String is not true/false: %r" % obj)
54 raise ValueError("String is not true/false: %r" % obj)
55
55
56 def guess_obj_type(obj):
56 def guess_obj_type(obj):
57 """Do everything to guess object type from string
57 """Do everything to guess object type from string
58
58
59 Tries to convert to `int`, `bool` and finally returns if not succeded.
59 Tries to convert to `int`, `bool` and finally returns if not succeded.
60
60
61 .. versionadded: 0.5.4
61 .. versionadded: 0.5.4
62 """
62 """
63
63
64 result = None
64 result = None
65
65
66 try:
66 try:
67 result = int(obj)
67 result = int(obj)
68 except:
68 except:
69 pass
69 pass
70
70
71 if result is None:
71 if result is None:
72 try:
72 try:
73 result = asbool(obj)
73 result = asbool(obj)
74 except:
74 except:
75 pass
75 pass
76
76
77 if result is not None:
77 if result is not None:
78 return result
78 return result
79 else:
79 else:
80 return obj
80 return obj
81
81
82 @decorator
82 @decorator
83 def catch_known_errors(f, *a, **kw):
83 def catch_known_errors(f, *a, **kw):
84 """Decorator that catches known api errors
84 """Decorator that catches known api errors
85
85
86 .. versionadded: 0.5.4
86 .. versionadded: 0.5.4
87 """
87 """
88
88
89 try:
89 try:
90 return f(*a, **kw)
90 return f(*a, **kw)
91 except exceptions.PathFoundError as e:
91 except exceptions.PathFoundError as e:
92 raise exceptions.KnownError("The path %s already exists" % e.args[0])
92 raise exceptions.KnownError("The path %s already exists" % e.args[0])
93
93
94 def construct_engine(engine, **opts):
94 def construct_engine(engine, **opts):
95 """.. versionadded:: 0.5.4
95 """.. versionadded:: 0.5.4
96
96
97 Constructs and returns SQLAlchemy engine.
97 Constructs and returns SQLAlchemy engine.
98
98
99 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
99 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
100
100
101 :param engine: connection string or a existing engine
101 :param engine: connection string or a existing engine
102 :param engine_dict: python dictionary of options to pass to `create_engine`
102 :param engine_dict: python dictionary of options to pass to `create_engine`
103 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
103 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
104 :type engine_dict: dict
104 :type engine_dict: dict
105 :type engine: string or Engine instance
105 :type engine: string or Engine instance
106 :type engine_arg_*: string
106 :type engine_arg_*: string
107 :returns: SQLAlchemy Engine
107 :returns: SQLAlchemy Engine
108
108
109 .. note::
109 .. note::
110
110
111 keyword parameters override ``engine_dict`` values.
111 keyword parameters override ``engine_dict`` values.
112
112
113 """
113 """
114 if isinstance(engine, Engine):
114 if isinstance(engine, Engine):
115 return engine
115 return engine
116 elif not isinstance(engine, str):
116 elif not isinstance(engine, str):
117 raise ValueError("you need to pass either an existing engine or a database uri")
117 raise ValueError("you need to pass either an existing engine or a database uri")
118
118
119 # get options for create_engine
119 # get options for create_engine
120 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
120 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
121 kwargs = opts['engine_dict']
121 kwargs = opts['engine_dict']
122 else:
122 else:
123 kwargs = {}
123 kwargs = {}
124
124
125 # DEPRECATED: handle echo the old way
125 # DEPRECATED: handle echo the old way
126 echo = asbool(opts.get('echo', False))
126 echo = asbool(opts.get('echo', False))
127 if echo:
127 if echo:
128 warnings.warn('echo=True parameter is deprecated, pass '
128 warnings.warn('echo=True parameter is deprecated, pass '
129 'engine_arg_echo=True or engine_dict={"echo": True}',
129 'engine_arg_echo=True or engine_dict={"echo": True}',
130 exceptions.MigrateDeprecationWarning)
130 exceptions.MigrateDeprecationWarning)
131 kwargs['echo'] = echo
131 kwargs['echo'] = echo
132
132
133 # parse keyword arguments
133 # parse keyword arguments
134 for key, value in opts.items():
134 for key, value in list(opts.items()):
135 if key.startswith('engine_arg_'):
135 if key.startswith('engine_arg_'):
136 kwargs[key[11:]] = guess_obj_type(value)
136 kwargs[key[11:]] = guess_obj_type(value)
137
137
138 log.debug('Constructing engine')
138 log.debug('Constructing engine')
139 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
139 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
140 # seems like 0.5.x branch does not work with engine.dispose and staticpool
140 # seems like 0.5.x branch does not work with engine.dispose and staticpool
141 return create_engine(engine, **kwargs)
141 return create_engine(engine, **kwargs)
142
142
143 @decorator
143 @decorator
144 def with_engine(f, *a, **kw):
144 def with_engine(f, *a, **kw):
145 """Decorator for :mod:`migrate.versioning.api` functions
145 """Decorator for :mod:`migrate.versioning.api` functions
146 to safely close resources after function usage.
146 to safely close resources after function usage.
147
147
148 Passes engine parameters to :func:`construct_engine` and
148 Passes engine parameters to :func:`construct_engine` and
149 resulting parameter is available as kw['engine'].
149 resulting parameter is available as kw['engine'].
150
150
151 Engine is disposed after wrapped function is executed.
151 Engine is disposed after wrapped function is executed.
152
152
153 .. versionadded: 0.6.0
153 .. versionadded: 0.6.0
154 """
154 """
155 url = a[0]
155 url = a[0]
156 engine = construct_engine(url, **kw)
156 engine = construct_engine(url, **kw)
157
157
158 try:
158 try:
159 kw['engine'] = engine
159 kw['engine'] = engine
160 return f(*a, **kw)
160 return f(*a, **kw)
161 finally:
161 finally:
162 if isinstance(engine, Engine) and engine is not url:
162 if isinstance(engine, Engine) and engine is not url:
163 log.debug('Disposing SQLAlchemy engine %s', engine)
163 log.debug('Disposing SQLAlchemy engine %s', engine)
164 engine.dispose()
164 engine.dispose()
165
165
166
166
167 class Memoize:
167 class Memoize:
168 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
168 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
169 Will only work on functions with non-mutable arguments
169 Will only work on functions with non-mutable arguments
170
170
171 ActiveState Code 52201
171 ActiveState Code 52201
172 """
172 """
173 def __init__(self, fn):
173 def __init__(self, fn):
174 self.fn = fn
174 self.fn = fn
175 self.memo = {}
175 self.memo = {}
176
176
177 def __call__(self, *args):
177 def __call__(self, *args):
178 if args not in self.memo:
178 if args not in self.memo:
179 self.memo[args] = self.fn(*args)
179 self.memo[args] = self.fn(*args)
180 return self.memo[args]
180 return self.memo[args]
@@ -1,15 +1,16 b''
1 import os
1 import os
2 import sys
2 import sys
3 import importlib
3
4
4 def import_path(fullpath):
5 def import_path(fullpath):
5 """ Import a file with full path specification. Allows one to
6 """ Import a file with full path specification. Allows one to
6 import from anywhere, something __import__ does not do.
7 import from anywhere, something __import__ does not do.
7 """
8 """
8 # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
9 # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
9 path, filename = os.path.split(fullpath)
10 path, filename = os.path.split(fullpath)
10 filename, ext = os.path.splitext(filename)
11 filename, ext = os.path.splitext(filename)
11 sys.path.append(path)
12 sys.path.append(path)
12 module = __import__(filename)
13 module = __import__(filename)
13 reload(module) # Might be out of date during tests
14 importlib.reload(module) # Might be out of date during tests
14 del sys.path[-1]
15 del sys.path[-1]
15 return module
16 return module
@@ -1,263 +1,263 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 import os
4 import os
5 import re
5 import re
6 import shutil
6 import shutil
7 import logging
7 import logging
8
8
9 from rhodecode.lib.dbmigrate.migrate import exceptions
9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
11 from datetime import datetime
11 from datetime import datetime
12
12
13
13
14 log = logging.getLogger(__name__)
14 log = logging.getLogger(__name__)
15
15
16 class VerNum(object):
16 class VerNum(object):
17 """A version number that behaves like a string and int at the same time"""
17 """A version number that behaves like a string and int at the same time"""
18
18
19 _instances = {}
19 _instances = {}
20
20
21 def __new__(cls, value):
21 def __new__(cls, value):
22 val = str(value)
22 val = str(value)
23 if val not in cls._instances:
23 if val not in cls._instances:
24 cls._instances[val] = super(VerNum, cls).__new__(cls)
24 cls._instances[val] = super(VerNum, cls).__new__(cls)
25 ret = cls._instances[val]
25 ret = cls._instances[val]
26 return ret
26 return ret
27
27
28 def __init__(self,value):
28 def __init__(self,value):
29 self.value = str(int(value))
29 self.value = str(int(value))
30 if self < 0:
30 if self < 0:
31 raise ValueError("Version number cannot be negative")
31 raise ValueError("Version number cannot be negative")
32
32
33 def __add__(self, value):
33 def __add__(self, value):
34 ret = int(self) + int(value)
34 ret = int(self) + int(value)
35 return VerNum(ret)
35 return VerNum(ret)
36
36
37 def __sub__(self, value):
37 def __sub__(self, value):
38 return self + (int(value) * -1)
38 return self + (int(value) * -1)
39
39
40 def __eq__(self, value):
40 def __eq__(self, value):
41 return int(self) == int(value)
41 return int(self) == int(value)
42
42
43 def __ne__(self, value):
43 def __ne__(self, value):
44 return int(self) != int(value)
44 return int(self) != int(value)
45
45
46 def __lt__(self, value):
46 def __lt__(self, value):
47 return int(self) < int(value)
47 return int(self) < int(value)
48
48
49 def __gt__(self, value):
49 def __gt__(self, value):
50 return int(self) > int(value)
50 return int(self) > int(value)
51
51
52 def __ge__(self, value):
52 def __ge__(self, value):
53 return int(self) >= int(value)
53 return int(self) >= int(value)
54
54
55 def __le__(self, value):
55 def __le__(self, value):
56 return int(self) <= int(value)
56 return int(self) <= int(value)
57
57
58 def __repr__(self):
58 def __repr__(self):
59 return "<VerNum(%s)>" % self.value
59 return "<VerNum(%s)>" % self.value
60
60
61 def __str__(self):
61 def __str__(self):
62 return str(self.value)
62 return str(self.value)
63
63
64 def __int__(self):
64 def __int__(self):
65 return int(self.value)
65 return int(self.value)
66
66
67
67
68 class Collection(pathed.Pathed):
68 class Collection(pathed.Pathed):
69 """A collection of versioning scripts in a repository"""
69 """A collection of versioning scripts in a repository"""
70
70
71 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
71 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
72
72
73 def __init__(self, path):
73 def __init__(self, path):
74 """Collect current version scripts in repository
74 """Collect current version scripts in repository
75 and store them in self.versions
75 and store them in self.versions
76 """
76 """
77 super(Collection, self).__init__(path)
77 super(Collection, self).__init__(path)
78
78
79 # Create temporary list of files, allowing skipped version numbers.
79 # Create temporary list of files, allowing skipped version numbers.
80 files = os.listdir(path)
80 files = os.listdir(path)
81 if '1' in files:
81 if '1' in files:
82 # deprecation
82 # deprecation
83 raise Exception('It looks like you have a repository in the old '
83 raise Exception('It looks like you have a repository in the old '
84 'format (with directories for each version). '
84 'format (with directories for each version). '
85 'Please convert repository before proceeding.')
85 'Please convert repository before proceeding.')
86
86
87 tempVersions = {}
87 tempVersions = {}
88 for filename in files:
88 for filename in files:
89 match = self.FILENAME_WITH_VERSION.match(filename)
89 match = self.FILENAME_WITH_VERSION.match(filename)
90 if match:
90 if match:
91 num = int(match.group(1))
91 num = int(match.group(1))
92 tempVersions.setdefault(num, []).append(filename)
92 tempVersions.setdefault(num, []).append(filename)
93 else:
93 else:
94 pass # Must be a helper file or something, let's ignore it.
94 pass # Must be a helper file or something, let's ignore it.
95
95
96 # Create the versions member where the keys
96 # Create the versions member where the keys
97 # are VerNum's and the values are Version's.
97 # are VerNum's and the values are Version's.
98 self.versions = {}
98 self.versions = {}
99 for num, files in tempVersions.items():
99 for num, files in list(tempVersions.items()):
100 self.versions[VerNum(num)] = Version(num, path, files)
100 self.versions[VerNum(num)] = Version(num, path, files)
101
101
102 @property
102 @property
103 def latest(self):
103 def latest(self):
104 """:returns: Latest version in Collection"""
104 """:returns: Latest version in Collection"""
105 return max([VerNum(0)] + self.versions.keys())
105 return max([VerNum(0)] + list(self.versions.keys()))
106
106
107 def _next_ver_num(self, use_timestamp_numbering):
107 def _next_ver_num(self, use_timestamp_numbering):
108 if use_timestamp_numbering == True:
108 if use_timestamp_numbering == True:
109 return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
109 return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
110 else:
110 else:
111 return self.latest + 1
111 return self.latest + 1
112
112
113 def create_new_python_version(self, description, **k):
113 def create_new_python_version(self, description, **k):
114 """Create Python files for new version"""
114 """Create Python files for new version"""
115 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
115 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
116 extra = str_to_filename(description)
116 extra = str_to_filename(description)
117
117
118 if extra:
118 if extra:
119 if extra == '_':
119 if extra == '_':
120 extra = ''
120 extra = ''
121 elif not extra.startswith('_'):
121 elif not extra.startswith('_'):
122 extra = '_%s' % extra
122 extra = '_%s' % extra
123
123
124 filename = '%03d%s.py' % (ver, extra)
124 filename = '%03d%s.py' % (ver, extra)
125 filepath = self._version_path(filename)
125 filepath = self._version_path(filename)
126
126
127 script.PythonScript.create(filepath, **k)
127 script.PythonScript.create(filepath, **k)
128 self.versions[ver] = Version(ver, self.path, [filename])
128 self.versions[ver] = Version(ver, self.path, [filename])
129
129
130 def create_new_sql_version(self, database, description, **k):
130 def create_new_sql_version(self, database, description, **k):
131 """Create SQL files for new version"""
131 """Create SQL files for new version"""
132 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
132 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
133 self.versions[ver] = Version(ver, self.path, [])
133 self.versions[ver] = Version(ver, self.path, [])
134
134
135 extra = str_to_filename(description)
135 extra = str_to_filename(description)
136
136
137 if extra:
137 if extra:
138 if extra == '_':
138 if extra == '_':
139 extra = ''
139 extra = ''
140 elif not extra.startswith('_'):
140 elif not extra.startswith('_'):
141 extra = '_%s' % extra
141 extra = '_%s' % extra
142
142
143 # Create new files.
143 # Create new files.
144 for op in ('upgrade', 'downgrade'):
144 for op in ('upgrade', 'downgrade'):
145 filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
145 filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
146 filepath = self._version_path(filename)
146 filepath = self._version_path(filename)
147 script.SqlScript.create(filepath, **k)
147 script.SqlScript.create(filepath, **k)
148 self.versions[ver].add_script(filepath)
148 self.versions[ver].add_script(filepath)
149
149
150 def version(self, vernum=None):
150 def version(self, vernum=None):
151 """Returns latest Version if vernum is not given.
151 """Returns latest Version if vernum is not given.
152 Otherwise, returns wanted version"""
152 Otherwise, returns wanted version"""
153 if vernum is None:
153 if vernum is None:
154 vernum = self.latest
154 vernum = self.latest
155 return self.versions[VerNum(vernum)]
155 return self.versions[VerNum(vernum)]
156
156
157 @classmethod
157 @classmethod
158 def clear(cls):
158 def clear(cls):
159 super(Collection, cls).clear()
159 super(Collection, cls).clear()
160
160
161 def _version_path(self, ver):
161 def _version_path(self, ver):
162 """Returns path of file in versions repository"""
162 """Returns path of file in versions repository"""
163 return os.path.join(self.path, str(ver))
163 return os.path.join(self.path, str(ver))
164
164
165
165
166 class Version(object):
166 class Version(object):
167 """A single version in a collection
167 """A single version in a collection
168 :param vernum: Version Number
168 :param vernum: Version Number
169 :param path: Path to script files
169 :param path: Path to script files
170 :param filelist: List of scripts
170 :param filelist: List of scripts
171 :type vernum: int, VerNum
171 :type vernum: int, VerNum
172 :type path: string
172 :type path: string
173 :type filelist: list
173 :type filelist: list
174 """
174 """
175
175
176 def __init__(self, vernum, path, filelist):
176 def __init__(self, vernum, path, filelist):
177 self.version = VerNum(vernum)
177 self.version = VerNum(vernum)
178
178
179 # Collect scripts in this folder
179 # Collect scripts in this folder
180 self.sql = {}
180 self.sql = {}
181 self.python = None
181 self.python = None
182
182
183 for script in filelist:
183 for script in filelist:
184 self.add_script(os.path.join(path, script))
184 self.add_script(os.path.join(path, script))
185
185
186 def script(self, database=None, operation=None):
186 def script(self, database=None, operation=None):
187 """Returns SQL or Python Script"""
187 """Returns SQL or Python Script"""
188 for db in (database, 'default'):
188 for db in (database, 'default'):
189 # Try to return a .sql script first
189 # Try to return a .sql script first
190 try:
190 try:
191 return self.sql[db][operation]
191 return self.sql[db][operation]
192 except KeyError:
192 except KeyError:
193 continue # No .sql script exists
193 continue # No .sql script exists
194
194
195 # TODO: maybe add force Python parameter?
195 # TODO: maybe add force Python parameter?
196 ret = self.python
196 ret = self.python
197
197
198 assert ret is not None, \
198 assert ret is not None, \
199 "There is no script for %d version" % self.version
199 "There is no script for %d version" % self.version
200 return ret
200 return ret
201
201
202 def add_script(self, path):
202 def add_script(self, path):
203 """Add script to Collection/Version"""
203 """Add script to Collection/Version"""
204 if path.endswith(Extensions.py):
204 if path.endswith(Extensions.py):
205 self._add_script_py(path)
205 self._add_script_py(path)
206 elif path.endswith(Extensions.sql):
206 elif path.endswith(Extensions.sql):
207 self._add_script_sql(path)
207 self._add_script_sql(path)
208
208
209 SQL_FILENAME = re.compile(r'^.*\.sql')
209 SQL_FILENAME = re.compile(r'^.*\.sql')
210
210
211 def _add_script_sql(self, path):
211 def _add_script_sql(self, path):
212 basename = os.path.basename(path)
212 basename = os.path.basename(path)
213 match = self.SQL_FILENAME.match(basename)
213 match = self.SQL_FILENAME.match(basename)
214
214
215 if match:
215 if match:
216 basename = basename.replace('.sql', '')
216 basename = basename.replace('.sql', '')
217 parts = basename.split('_')
217 parts = basename.split('_')
218 if len(parts) < 3:
218 if len(parts) < 3:
219 raise exceptions.ScriptError(
219 raise exceptions.ScriptError(
220 "Invalid SQL script name %s " % basename + \
220 "Invalid SQL script name %s " % basename + \
221 "(needs to be ###_description_database_operation.sql)")
221 "(needs to be ###_description_database_operation.sql)")
222 version = parts[0]
222 version = parts[0]
223 op = parts[-1]
223 op = parts[-1]
224 # NOTE(mriedem): check for ibm_db_sa as the database in the name
224 # NOTE(mriedem): check for ibm_db_sa as the database in the name
225 if 'ibm_db_sa' in basename:
225 if 'ibm_db_sa' in basename:
226 if len(parts) == 6:
226 if len(parts) == 6:
227 dbms = '_'.join(parts[-4: -1])
227 dbms = '_'.join(parts[-4: -1])
228 else:
228 else:
229 raise exceptions.ScriptError(
229 raise exceptions.ScriptError(
230 "Invalid ibm_db_sa SQL script name '%s'; "
230 "Invalid ibm_db_sa SQL script name '%s'; "
231 "(needs to be "
231 "(needs to be "
232 "###_description_ibm_db_sa_operation.sql)" % basename)
232 "###_description_ibm_db_sa_operation.sql)" % basename)
233 else:
233 else:
234 dbms = parts[-2]
234 dbms = parts[-2]
235 else:
235 else:
236 raise exceptions.ScriptError(
236 raise exceptions.ScriptError(
237 "Invalid SQL script name %s " % basename + \
237 "Invalid SQL script name %s " % basename + \
238 "(needs to be ###_description_database_operation.sql)")
238 "(needs to be ###_description_database_operation.sql)")
239
239
240 # File the script into a dictionary
240 # File the script into a dictionary
241 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
241 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
242
242
243 def _add_script_py(self, path):
243 def _add_script_py(self, path):
244 if self.python is not None:
244 if self.python is not None:
245 raise exceptions.ScriptError('You can only have one Python script '
245 raise exceptions.ScriptError('You can only have one Python script '
246 'per version, but you have: %s and %s' % (self.python, path))
246 'per version, but you have: %s and %s' % (self.python, path))
247 self.python = script.PythonScript(path)
247 self.python = script.PythonScript(path)
248
248
249
249
250 class Extensions:
250 class Extensions:
251 """A namespace for file extensions"""
251 """A namespace for file extensions"""
252 py = 'py'
252 py = 'py'
253 sql = 'sql'
253 sql = 'sql'
254
254
255 def str_to_filename(s):
255 def str_to_filename(s):
256 """Replaces spaces, (double and single) quotes
256 """Replaces spaces, (double and single) quotes
257 and double underscores to underscores
257 and double underscores to underscores
258 """
258 """
259
259
260 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
260 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
261 while '__' in s:
261 while '__' in s:
262 s = s.replace('__', '_')
262 s = s.replace('__', '_')
263 return s
263 return s
General Comments 0
You need to be logged in to leave comments. Login now