##// END OF EJS Templates
update migrations for 1.2
marcink -
r1442:7f31de15 beta
parent child Browse files
Show More
@@ -1,9 +1,11 b''
1 1 """
2 2 SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
3 3 database schema version and repository management and
4 4 :mod:`migrate.changeset` that allows to define database schema changes
5 5 using Python.
6 6 """
7 7
8 8 from rhodecode.lib.dbmigrate.migrate.versioning import *
9 9 from rhodecode.lib.dbmigrate.migrate.changeset import *
10
11 __version__ = '0.7.2.dev' No newline at end of file
@@ -1,29 +1,30 b''
1 1 """
2 2 This module extends SQLAlchemy and provides additional DDL [#]_
3 3 support.
4 4
5 5 .. [#] SQL Data Definition Language
6 6 """
7 7 import re
8 8 import warnings
9 9
10 10 import sqlalchemy
11 11 from sqlalchemy import __version__ as _sa_version
12 12
13 13 warnings.simplefilter('always', DeprecationWarning)
14 14
15 _sa_version = tuple(int(re.match("\d+", x).group(0))
15 _sa_version = tuple(int(re.match("\d+", x).group(0))
16 16 for x in _sa_version.split("."))
17 17 SQLA_06 = _sa_version >= (0, 6)
18 SQLA_07 = _sa_version >= (0, 7)
18 19
19 20 del re
20 21 del _sa_version
21 22
22 23 from rhodecode.lib.dbmigrate.migrate.changeset.schema import *
23 24 from rhodecode.lib.dbmigrate.migrate.changeset.constraint import *
24 25
25 26 sqlalchemy.schema.Table.__bases__ += (ChangesetTable,)
26 27 sqlalchemy.schema.Column.__bases__ += (ChangesetColumn,)
27 28 sqlalchemy.schema.Index.__bases__ += (ChangesetIndex,)
28 29
29 30 sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause,)
@@ -1,651 +1,657 b''
1 1 """
2 2 Schema module providing common schema operations.
3 3 """
4 4 import warnings
5 5
6 6 from UserDict import DictMixin
7 7
8 8 import sqlalchemy
9 9
10 10 from sqlalchemy.schema import ForeignKeyConstraint
11 11 from sqlalchemy.schema import UniqueConstraint
12 12
13 13 from rhodecode.lib.dbmigrate.migrate.exceptions import *
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06, SQLA_07
15 15 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
16 run_single_visitor)
16 run_single_visitor)
17 17
18 18
19 19 __all__ = [
20 20 'create_column',
21 21 'drop_column',
22 22 'alter_column',
23 23 'rename_table',
24 24 'rename_index',
25 25 'ChangesetTable',
26 26 'ChangesetColumn',
27 27 'ChangesetIndex',
28 28 'ChangesetDefaultClause',
29 29 'ColumnDelta',
30 30 ]
31 31
32 32 def create_column(column, table=None, *p, **kw):
33 33 """Create a column, given the table.
34 34
35 35 API to :meth:`ChangesetColumn.create`.
36 36 """
37 37 if table is not None:
38 38 return table.create_column(column, *p, **kw)
39 39 return column.create(*p, **kw)
40 40
41 41
42 42 def drop_column(column, table=None, *p, **kw):
43 43 """Drop a column, given the table.
44 44
45 45 API to :meth:`ChangesetColumn.drop`.
46 46 """
47 47 if table is not None:
48 48 return table.drop_column(column, *p, **kw)
49 49 return column.drop(*p, **kw)
50 50
51 51
52 52 def rename_table(table, name, engine=None, **kw):
53 53 """Rename a table.
54 54
55 55 If Table instance is given, engine is not used.
56 56
57 57 API to :meth:`ChangesetTable.rename`.
58 58
59 59 :param table: Table to be renamed.
60 60 :param name: New name for Table.
61 61 :param engine: Engine instance.
62 62 :type table: string or Table instance
63 63 :type name: string
64 64 :type engine: obj
65 65 """
66 66 table = _to_table(table, engine)
67 67 table.rename(name, **kw)
68 68
69 69
70 70 def rename_index(index, name, table=None, engine=None, **kw):
71 71 """Rename an index.
72 72
73 73 If Index instance is given,
74 74 table and engine are not used.
75 75
76 76 API to :meth:`ChangesetIndex.rename`.
77 77
78 78 :param index: Index to be renamed.
79 79 :param name: New name for index.
80 80 :param table: Table to which Index is reffered.
81 81 :param engine: Engine instance.
82 82 :type index: string or Index instance
83 83 :type name: string
84 84 :type table: string or Table instance
85 85 :type engine: obj
86 86 """
87 87 index = _to_index(index, table, engine)
88 88 index.rename(name, **kw)
89 89
90 90
91 91 def alter_column(*p, **k):
92 92 """Alter a column.
93 93
94 94 This is a helper function that creates a :class:`ColumnDelta` and
95 95 runs it.
96 96
97 97 :argument column:
98 98 The name of the column to be altered or a
99 99 :class:`ChangesetColumn` column representing it.
100 100
101 101 :param table:
102 102 A :class:`~sqlalchemy.schema.Table` or table name to
103 103 for the table where the column will be changed.
104 104
105 105 :param engine:
106 106 The :class:`~sqlalchemy.engine.base.Engine` to use for table
107 107 reflection and schema alterations.
108 108
109 109 :returns: A :class:`ColumnDelta` instance representing the change.
110 110
111 111
112 112 """
113 113
114 114 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
115 115 k['table'] = p[0].table
116 116 if 'engine' not in k:
117 117 k['engine'] = k['table'].bind
118 118
119 119 # deprecation
120 120 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
121 121 warnings.warn(
122 122 "Passing a Column object to alter_column is deprecated."
123 123 " Just pass in keyword parameters instead.",
124 124 MigrateDeprecationWarning
125 125 )
126 126 engine = k['engine']
127 127
128 128 # enough tests seem to break when metadata is always altered
129 129 # that this crutch has to be left in until they can be sorted
130 130 # out
131 131 k['alter_metadata']=True
132 132
133 133 delta = ColumnDelta(*p, **k)
134 134
135 135 visitorcallable = get_engine_visitor(engine, 'schemachanger')
136 136 engine._run_visitor(visitorcallable, delta)
137 137
138 138 return delta
139 139
140 140
141 141 def _to_table(table, engine=None):
142 142 """Return if instance of Table, else construct new with metadata"""
143 143 if isinstance(table, sqlalchemy.Table):
144 144 return table
145 145
146 146 # Given: table name, maybe an engine
147 147 meta = sqlalchemy.MetaData()
148 148 if engine is not None:
149 149 meta.bind = engine
150 150 return sqlalchemy.Table(table, meta)
151 151
152 152
153 153 def _to_index(index, table=None, engine=None):
154 154 """Return if instance of Index, else construct new with metadata"""
155 155 if isinstance(index, sqlalchemy.Index):
156 156 return index
157 157
158 158 # Given: index name; table name required
159 159 table = _to_table(table, engine)
160 160 ret = sqlalchemy.Index(index)
161 161 ret.table = table
162 162 return ret
163 163
164 164
165 165 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
166 166 """Extracts the differences between two columns/column-parameters
167 167
168 168 May receive parameters arranged in several different ways:
169 169
170 170 * **current_column, new_column, \*p, \*\*kw**
171 171 Additional parameters can be specified to override column
172 172 differences.
173 173
174 174 * **current_column, \*p, \*\*kw**
175 175 Additional parameters alter current_column. Table name is extracted
176 176 from current_column object.
177 177 Name is changed to current_column.name from current_name,
178 178 if current_name is specified.
179 179
180 180 * **current_col_name, \*p, \*\*kw**
181 181 Table kw must specified.
182 182
183 183 :param table: Table at which current Column should be bound to.\
184 184 If table name is given, reflection will be used.
185 185 :type table: string or Table instance
186 186
187 187 :param metadata: A :class:`MetaData` instance to store
188 188 reflected table names
189 189
190 190 :param engine: When reflecting tables, either engine or metadata must \
191 191 be specified to acquire engine object.
192 192 :type engine: :class:`Engine` instance
193 193 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
194 194 `result_column` through :func:`dict` alike object.
195 195
196 196 * :class:`ColumnDelta`.result_column is altered column with new attributes
197 197
198 198 * :class:`ColumnDelta`.current_name is current name of column in db
199 199
200 200
201 201 """
202 202
203 203 # Column attributes that can be altered
204 204 diff_keys = ('name', 'type', 'primary_key', 'nullable',
205 205 'server_onupdate', 'server_default', 'autoincrement')
206 206 diffs = dict()
207 207 __visit_name__ = 'column'
208 208
209 209 def __init__(self, *p, **kw):
210 210 # 'alter_metadata' is not a public api. It exists purely
211 211 # as a crutch until the tests that fail when 'alter_metadata'
212 212 # behaviour always happens can be sorted out
213 213 self.alter_metadata = kw.pop("alter_metadata", False)
214 214
215 215 self.meta = kw.pop("metadata", None)
216 216 self.engine = kw.pop("engine", None)
217 217
218 218 # Things are initialized differently depending on how many column
219 219 # parameters are given. Figure out how many and call the appropriate
220 220 # method.
221 221 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
222 222 # At least one column specified
223 223 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
224 224 # Two columns specified
225 225 diffs = self.compare_2_columns(*p, **kw)
226 226 else:
227 227 # Exactly one column specified
228 228 diffs = self.compare_1_column(*p, **kw)
229 229 else:
230 230 # Zero columns specified
231 231 if not len(p) or not isinstance(p[0], basestring):
232 232 raise ValueError("First argument must be column name")
233 233 diffs = self.compare_parameters(*p, **kw)
234 234
235 235 self.apply_diffs(diffs)
236 236
237 237 def __repr__(self):
238 238 return '<ColumnDelta altermetadata=%r, %s>' % (
239 239 self.alter_metadata,
240 240 super(ColumnDelta, self).__repr__()
241 241 )
242 242
243 243 def __getitem__(self, key):
244 244 if key not in self.keys():
245 245 raise KeyError("No such diff key, available: %s" % self.diffs )
246 246 return getattr(self.result_column, key)
247 247
248 248 def __setitem__(self, key, value):
249 249 if key not in self.keys():
250 250 raise KeyError("No such diff key, available: %s" % self.diffs )
251 251 setattr(self.result_column, key, value)
252 252
253 253 def __delitem__(self, key):
254 254 raise NotImplementedError
255 255
256 256 def keys(self):
257 257 return self.diffs.keys()
258 258
259 259 def compare_parameters(self, current_name, *p, **k):
260 260 """Compares Column objects with reflection"""
261 261 self.table = k.pop('table')
262 262 self.result_column = self._table.c.get(current_name)
263 263 if len(p):
264 264 k = self._extract_parameters(p, k, self.result_column)
265 265 return k
266 266
267 267 def compare_1_column(self, col, *p, **k):
268 268 """Compares one Column object"""
269 269 self.table = k.pop('table', None)
270 270 if self.table is None:
271 271 self.table = col.table
272 272 self.result_column = col
273 273 if len(p):
274 274 k = self._extract_parameters(p, k, self.result_column)
275 275 return k
276 276
277 277 def compare_2_columns(self, old_col, new_col, *p, **k):
278 278 """Compares two Column objects"""
279 279 self.process_column(new_col)
280 280 self.table = k.pop('table', None)
281 281 # we cannot use bool() on table in SA06
282 282 if self.table is None:
283 283 self.table = old_col.table
284 284 if self.table is None:
285 285 new_col.table
286 286 self.result_column = old_col
287 287
288 288 # set differences
289 289 # leave out some stuff for later comp
290 290 for key in (set(self.diff_keys) - set(('type',))):
291 291 val = getattr(new_col, key, None)
292 292 if getattr(self.result_column, key, None) != val:
293 293 k.setdefault(key, val)
294 294
295 295 # inspect types
296 296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
297 297 k.setdefault('type', new_col.type)
298 298
299 299 if len(p):
300 300 k = self._extract_parameters(p, k, self.result_column)
301 301 return k
302 302
303 303 def apply_diffs(self, diffs):
304 304 """Populate dict and column object with new values"""
305 305 self.diffs = diffs
306 306 for key in self.diff_keys:
307 307 if key in diffs:
308 308 setattr(self.result_column, key, diffs[key])
309 309
310 310 self.process_column(self.result_column)
311 311
312 312 # create an instance of class type if not yet
313 313 if 'type' in diffs and callable(self.result_column.type):
314 314 self.result_column.type = self.result_column.type()
315 315
316 316 # add column to the table
317 317 if self.table is not None and self.alter_metadata:
318 318 self.result_column.add_to_table(self.table)
319 319
320 320 def are_column_types_eq(self, old_type, new_type):
321 321 """Compares two types to be equal"""
322 322 ret = old_type.__class__ == new_type.__class__
323 323
324 324 # String length is a special case
325 325 if ret and isinstance(new_type, sqlalchemy.types.String):
326 326 ret = (getattr(old_type, 'length', None) == \
327 327 getattr(new_type, 'length', None))
328 328 return ret
329 329
330 330 def _extract_parameters(self, p, k, column):
331 331 """Extracts data from p and modifies diffs"""
332 332 p = list(p)
333 333 while len(p):
334 334 if isinstance(p[0], basestring):
335 335 k.setdefault('name', p.pop(0))
336 336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
337 337 k.setdefault('type', p.pop(0))
338 338 elif callable(p[0]):
339 339 p[0] = p[0]()
340 340 else:
341 341 break
342 342
343 343 if len(p):
344 344 new_col = column.copy_fixed()
345 345 new_col._init_items(*p)
346 346 k = self.compare_2_columns(column, new_col, **k)
347 347 return k
348 348
349 349 def process_column(self, column):
350 350 """Processes default values for column"""
351 351 # XXX: this is a snippet from SA processing of positional parameters
352 352 if not SQLA_06 and column.args:
353 353 toinit = list(column.args)
354 354 else:
355 355 toinit = list()
356 356
357 357 if column.server_default is not None:
358 358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
359 359 toinit.append(column.server_default)
360 360 else:
361 361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
362 362 if column.server_onupdate is not None:
363 363 if isinstance(column.server_onupdate, FetchedValue):
364 364 toinit.append(column.server_default)
365 365 else:
366 366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
367 367 for_update=True))
368 368 if toinit:
369 369 column._init_items(*toinit)
370 370
371 371 if not SQLA_06:
372 372 column.args = []
373 373
374 374 def _get_table(self):
375 375 return getattr(self, '_table', None)
376 376
377 377 def _set_table(self, table):
378 378 if isinstance(table, basestring):
379 379 if self.alter_metadata:
380 380 if not self.meta:
381 381 raise ValueError("metadata must be specified for table"
382 382 " reflection when using alter_metadata")
383 383 meta = self.meta
384 384 if self.engine:
385 385 meta.bind = self.engine
386 386 else:
387 387 if not self.engine and not self.meta:
388 388 raise ValueError("engine or metadata must be specified"
389 389 " to reflect tables")
390 390 if not self.engine:
391 391 self.engine = self.meta.bind
392 392 meta = sqlalchemy.MetaData(bind=self.engine)
393 393 self._table = sqlalchemy.Table(table, meta, autoload=True)
394 394 elif isinstance(table, sqlalchemy.Table):
395 395 self._table = table
396 396 if not self.alter_metadata:
397 397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
398 398 def _get_result_column(self):
399 399 return getattr(self, '_result_column', None)
400 400
401 401 def _set_result_column(self, column):
402 402 """Set Column to Table based on alter_metadata evaluation."""
403 403 self.process_column(column)
404 404 if not hasattr(self, 'current_name'):
405 405 self.current_name = column.name
406 406 if self.alter_metadata:
407 407 self._result_column = column
408 408 else:
409 409 self._result_column = column.copy_fixed()
410 410
411 411 table = property(_get_table, _set_table)
412 412 result_column = property(_get_result_column, _set_result_column)
413 413
414 414
415 415 class ChangesetTable(object):
416 416 """Changeset extensions to SQLAlchemy tables."""
417 417
418 418 def create_column(self, column, *p, **kw):
419 419 """Creates a column.
420 420
421 421 The column parameter may be a column definition or the name of
422 422 a column in this table.
423 423
424 424 API to :meth:`ChangesetColumn.create`
425 425
426 426 :param column: Column to be created
427 427 :type column: Column instance or string
428 428 """
429 429 if not isinstance(column, sqlalchemy.Column):
430 430 # It's a column name
431 431 column = getattr(self.c, str(column))
432 432 column.create(table=self, *p, **kw)
433 433
434 434 def drop_column(self, column, *p, **kw):
435 435 """Drop a column, given its name or definition.
436 436
437 437 API to :meth:`ChangesetColumn.drop`
438 438
439 439 :param column: Column to be droped
440 440 :type column: Column instance or string
441 441 """
442 442 if not isinstance(column, sqlalchemy.Column):
443 443 # It's a column name
444 444 try:
445 445 column = getattr(self.c, str(column))
446 446 except AttributeError:
447 447 # That column isn't part of the table. We don't need
448 448 # its entire definition to drop the column, just its
449 449 # name, so create a dummy column with the same name.
450 450 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
451 451 column.drop(table=self, *p, **kw)
452 452
453 453 def rename(self, name, connection=None, **kwargs):
454 454 """Rename this table.
455 455
456 456 :param name: New name of the table.
457 457 :type name: string
458 458 :param connection: reuse connection istead of creating new one.
459 459 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
460 460 """
461 461 engine = self.bind
462 462 self.new_name = name
463 463 visitorcallable = get_engine_visitor(engine, 'schemachanger')
464 464 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
465 465
466 466 # Fix metadata registration
467 467 self.name = name
468 468 self.deregister()
469 469 self._set_parent(self.metadata)
470 470
471 471 def _meta_key(self):
472 472 return sqlalchemy.schema._get_table_key(self.name, self.schema)
473 473
474 474 def deregister(self):
475 475 """Remove this table from its metadata"""
476 476 key = self._meta_key()
477 477 meta = self.metadata
478 478 if key in meta.tables:
479 479 del meta.tables[key]
480 480
481 481
482 482 class ChangesetColumn(object):
483 483 """Changeset extensions to SQLAlchemy columns."""
484 484
485 485 def alter(self, *p, **k):
486 486 """Makes a call to :func:`alter_column` for the column this
487 487 method is called on.
488 488 """
489 489 if 'table' not in k:
490 490 k['table'] = self.table
491 491 if 'engine' not in k:
492 492 k['engine'] = k['table'].bind
493 493 return alter_column(self, *p, **k)
494 494
495 495 def create(self, table=None, index_name=None, unique_name=None,
496 496 primary_key_name=None, populate_default=True, connection=None, **kwargs):
497 497 """Create this column in the database.
498 498
499 499 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
500 500 for most databases.
501 501
502 502 :param table: Table instance to create on.
503 503 :param index_name: Creates :class:`ChangesetIndex` on this column.
504 504 :param unique_name: Creates :class:\
505 505 `~migrate.changeset.constraint.UniqueConstraint` on this column.
506 506 :param primary_key_name: Creates :class:\
507 507 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
508 508 :param populate_default: If True, created column will be \
509 509 populated with defaults
510 510 :param connection: reuse connection istead of creating new one.
511 511 :type table: Table instance
512 512 :type index_name: string
513 513 :type unique_name: string
514 514 :type primary_key_name: string
515 515 :type populate_default: bool
516 516 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
517 517
518 518 :returns: self
519 519 """
520 520 self.populate_default = populate_default
521 521 self.index_name = index_name
522 522 self.unique_name = unique_name
523 523 self.primary_key_name = primary_key_name
524 524 for cons in ('index_name', 'unique_name', 'primary_key_name'):
525 525 self._check_sanity_constraints(cons)
526 526
527 527 self.add_to_table(table)
528 528 engine = self.table.bind
529 529 visitorcallable = get_engine_visitor(engine, 'columngenerator')
530 530 engine._run_visitor(visitorcallable, self, connection, **kwargs)
531 531
532 532 # TODO: reuse existing connection
533 533 if self.populate_default and self.default is not None:
534 534 stmt = table.update().values({self: engine._execute_default(self.default)})
535 535 engine.execute(stmt)
536 536
537 537 return self
538 538
539 539 def drop(self, table=None, connection=None, **kwargs):
540 540 """Drop this column from the database, leaving its table intact.
541 541
542 542 ``ALTER TABLE DROP COLUMN``, for most databases.
543 543
544 544 :param connection: reuse connection istead of creating new one.
545 545 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
546 546 """
547 547 if table is not None:
548 548 self.table = table
549 549 engine = self.table.bind
550 550 visitorcallable = get_engine_visitor(engine, 'columndropper')
551 551 engine._run_visitor(visitorcallable, self, connection, **kwargs)
552 552 self.remove_from_table(self.table, unset_table=False)
553 553 self.table = None
554 554 return self
555 555
556 556 def add_to_table(self, table):
557 557 if table is not None and self.table is None:
558 self._set_parent(table)
558 if SQLA_07:
559 table.append_column(self)
560 else:
561 self._set_parent(table)
559 562
560 563 def _col_name_in_constraint(self,cons,name):
561 564 return False
562 565
563 566 def remove_from_table(self, table, unset_table=True):
564 567 # TODO: remove primary keys, constraints, etc
565 568 if unset_table:
566 569 self.table = None
567 570
568 571 to_drop = set()
569 572 for index in table.indexes:
570 573 columns = []
571 574 for col in index.columns:
572 575 if col.name!=self.name:
573 576 columns.append(col)
574 577 if columns:
575 578 index.columns=columns
576 579 else:
577 580 to_drop.add(index)
578 581 table.indexes = table.indexes - to_drop
579 582
580 583 to_drop = set()
581 584 for cons in table.constraints:
582 585 # TODO: deal with other types of constraint
583 586 if isinstance(cons,(ForeignKeyConstraint,
584 587 UniqueConstraint)):
585 588 for col_name in cons.columns:
586 589 if not isinstance(col_name,basestring):
587 590 col_name = col_name.name
588 591 if self.name==col_name:
589 592 to_drop.add(cons)
590 593 table.constraints = table.constraints - to_drop
591 594
592 595 if table.c.contains_column(self):
593 table.c.remove(self)
596 if SQLA_07:
597 table._columns.remove(self)
598 else:
599 table.c.remove(self)
594 600
595 601 # TODO: this is fixed in 0.6
596 602 def copy_fixed(self, **kw):
597 603 """Create a copy of this ``Column``, with all attributes."""
598 604 return sqlalchemy.Column(self.name, self.type, self.default,
599 605 key=self.key,
600 606 primary_key=self.primary_key,
601 607 nullable=self.nullable,
602 608 quote=self.quote,
603 609 index=self.index,
604 610 unique=self.unique,
605 611 onupdate=self.onupdate,
606 612 autoincrement=self.autoincrement,
607 613 server_default=self.server_default,
608 614 server_onupdate=self.server_onupdate,
609 615 *[c.copy(**kw) for c in self.constraints])
610 616
611 617 def _check_sanity_constraints(self, name):
612 618 """Check if constraints names are correct"""
613 619 obj = getattr(self, name)
614 620 if (getattr(self, name[:-5]) and not obj):
615 621 raise InvalidConstraintError("Column.create() accepts index_name,"
616 622 " primary_key_name and unique_name to generate constraints")
617 623 if not isinstance(obj, basestring) and obj is not None:
618 624 raise InvalidConstraintError(
619 625 "%s argument for column must be constraint name" % name)
620 626
621 627
622 628 class ChangesetIndex(object):
623 629 """Changeset extensions to SQLAlchemy Indexes."""
624 630
625 631 __visit_name__ = 'index'
626 632
627 633 def rename(self, name, connection=None, **kwargs):
628 634 """Change the name of an index.
629 635
630 636 :param name: New name of the Index.
631 637 :type name: string
632 638 :param connection: reuse connection istead of creating new one.
633 639 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
634 640 """
635 641 engine = self.table.bind
636 642 self.new_name = name
637 643 visitorcallable = get_engine_visitor(engine, 'schemachanger')
638 644 engine._run_visitor(visitorcallable, self, connection, **kwargs)
639 645 self.name = name
640 646
641 647
642 648 class ChangesetDefaultClause(object):
643 649 """Implements comparison between :class:`DefaultClause` instances"""
644 650
645 651 def __eq__(self, other):
646 652 if isinstance(other, self.__class__):
647 653 if self.arg == other.arg:
648 654 return True
649 655
650 656 def __ne__(self, other):
651 657 return not self.__eq__(other)
@@ -1,83 +1,88 b''
1 1 """
2 2 Provide exception classes for :mod:`migrate`
3 3 """
4 4
5 5
6 6 class Error(Exception):
7 7 """Error base class."""
8 8
9 9
10 10 class ApiError(Error):
11 11 """Base class for API errors."""
12 12
13 13
14 14 class KnownError(ApiError):
15 15 """A known error condition."""
16 16
17 17
18 18 class UsageError(ApiError):
19 19 """A known error condition where help should be displayed."""
20 20
21 21
22 22 class ControlledSchemaError(Error):
23 23 """Base class for controlled schema errors."""
24 24
25 25
26 26 class InvalidVersionError(ControlledSchemaError):
27 27 """Invalid version number."""
28 28
29 29
30 30 class DatabaseNotControlledError(ControlledSchemaError):
31 31 """Database should be under version control, but it's not."""
32 32
33 33
34 34 class DatabaseAlreadyControlledError(ControlledSchemaError):
35 35 """Database shouldn't be under version control, but it is"""
36 36
37 37
38 38 class WrongRepositoryError(ControlledSchemaError):
39 39 """This database is under version control by another repository."""
40 40
41 41
42 42 class NoSuchTableError(ControlledSchemaError):
43 43 """The table does not exist."""
44 44
45 45
46 46 class PathError(Error):
47 47 """Base class for path errors."""
48 48
49 49
50 50 class PathNotFoundError(PathError):
51 51 """A path with no file was required; found a file."""
52 52
53 53
54 54 class PathFoundError(PathError):
55 55 """A path with a file was required; found no file."""
56 56
57 57
58 58 class RepositoryError(Error):
59 59 """Base class for repository errors."""
60 60
61 61
62 62 class InvalidRepositoryError(RepositoryError):
63 63 """Invalid repository error."""
64 64
65 65
66 66 class ScriptError(Error):
67 67 """Base class for script errors."""
68 68
69 69
70 70 class InvalidScriptError(ScriptError):
71 71 """Invalid script error."""
72 72
73 73
74 class InvalidVersionError(Error):
75 """Invalid version error."""
76
77 # migrate.changeset
78
74 79 class NotSupportedError(Error):
75 80 """Not supported error"""
76 81
77 82
78 83 class InvalidConstraintError(Error):
79 84 """Invalid constraint error"""
80 85
81 86
82 87 class MigrateDeprecationWarning(DeprecationWarning):
83 88 """Warning for deprecated features in Migrate"""
@@ -1,383 +1,383 b''
1 1 """
2 2 This module provides an external API to the versioning system.
3 3
4 4 .. versionchanged:: 0.6.0
5 5 :func:`migrate.versioning.api.test` and schema diff functions
6 6 changed order of positional arguments so all accept `url` and `repository`
7 7 as first arguments.
8 8
9 9 .. versionchanged:: 0.5.4
10 10 ``--preview_sql`` displays source file when using SQL scripts.
11 11 If Python script is used, it runs the action with mocked engine and
12 12 returns captured SQL statements.
13 13
14 14 .. versionchanged:: 0.5.4
15 15 Deprecated ``--echo`` parameter in favour of new
16 16 :func:`migrate.versioning.util.construct_engine` behavior.
17 17 """
18 18
19 19 # Dear migrate developers,
20 20 #
21 21 # please do not comment this module using sphinx syntax because its
22 22 # docstrings are presented as user help and most users cannot
23 23 # interpret sphinx annotated ReStructuredText.
24 24 #
25 25 # Thanks,
26 26 # Jan Dittberner
27 27
28 28 import sys
29 29 import inspect
30 30 import logging
31 31
32 32 from rhodecode.lib.dbmigrate.migrate import exceptions
33 33 from rhodecode.lib.dbmigrate.migrate.versioning import repository, schema, version, \
34 34 script as script_ # command name conflict
35 35 from rhodecode.lib.dbmigrate.migrate.versioning.util import catch_known_errors, with_engine
36 36
37 37
38 38 log = logging.getLogger(__name__)
39 39 command_desc = {
40 40 'help': 'displays help on a given command',
41 41 'create': 'create an empty repository at the specified path',
42 42 'script': 'create an empty change Python script',
43 43 'script_sql': 'create empty change SQL scripts for given database',
44 44 'version': 'display the latest version available in a repository',
45 45 'db_version': 'show the current version of the repository under version control',
46 46 'source': 'display the Python code for a particular version in this repository',
47 47 'version_control': 'mark a database as under this repository\'s version control',
48 48 'upgrade': 'upgrade a database to a later version',
49 49 'downgrade': 'downgrade a database to an earlier version',
50 50 'drop_version_control': 'removes version control from a database',
51 51 'manage': 'creates a Python script that runs Migrate with a set of default values',
52 52 'test': 'performs the upgrade and downgrade command on the given database',
53 53 'compare_model_to_db': 'compare MetaData against the current database state',
54 54 'create_model': 'dump the current database as a Python model to stdout',
55 55 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
56 56 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
57 57 }
58 58 __all__ = command_desc.keys()
59 59
60 60 Repository = repository.Repository
61 61 ControlledSchema = schema.ControlledSchema
62 62 VerNum = version.VerNum
63 63 PythonScript = script_.PythonScript
64 64 SqlScript = script_.SqlScript
65 65
66 66
67 67 # deprecated
68 68 def help(cmd=None, **opts):
69 69 """%prog help COMMAND
70 70
71 71 Displays help on a given command.
72 72 """
73 73 if cmd is None:
74 74 raise exceptions.UsageError(None)
75 75 try:
76 76 func = globals()[cmd]
77 77 except:
78 78 raise exceptions.UsageError(
79 79 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
80 80 ret = func.__doc__
81 81 if sys.argv[0]:
82 82 ret = ret.replace('%prog', sys.argv[0])
83 83 return ret
84 84
85 85 @catch_known_errors
86 86 def create(repository, name, **opts):
87 87 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
88 88
89 89 Create an empty repository at the specified path.
90 90
91 91 You can specify the version_table to be used; by default, it is
92 92 'migrate_version'. This table is created in all version-controlled
93 93 databases.
94 94 """
95 95 repo_path = Repository.create(repository, name, **opts)
96 96
97 97
98 98 @catch_known_errors
99 99 def script(description, repository, **opts):
100 100 """%prog script DESCRIPTION REPOSITORY_PATH
101 101
102 102 Create an empty change script using the next unused version number
103 103 appended with the given description.
104 104
105 105 For instance, manage.py script "Add initial tables" creates:
106 106 repository/versions/001_Add_initial_tables.py
107 107 """
108 108 repo = Repository(repository)
109 109 repo.create_script(description, **opts)
110 110
111 111
112 112 @catch_known_errors
113 def script_sql(database, repository, **opts):
114 """%prog script_sql DATABASE REPOSITORY_PATH
113 def script_sql(database, description, repository, **opts):
114 """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
115 115
116 116 Create empty change SQL scripts for given DATABASE, where DATABASE
117 is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
117 is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
118 118 or generic ('default').
119 119
120 For instance, manage.py script_sql postgres creates:
121 repository/versions/001_postgres_upgrade.sql and
122 repository/versions/001_postgres_postgres.sql
120 For instance, manage.py script_sql postgresql description creates:
121 repository/versions/001_description_postgresql_upgrade.sql and
122 repository/versions/001_description_postgresql_postgres.sql
123 123 """
124 124 repo = Repository(repository)
125 repo.create_script_sql(database, **opts)
125 repo.create_script_sql(database, description, **opts)
126 126
127 127
128 128 def version(repository, **opts):
129 129 """%prog version REPOSITORY_PATH
130 130
131 131 Display the latest version available in a repository.
132 132 """
133 133 repo = Repository(repository)
134 134 return repo.latest
135 135
136 136
137 137 @with_engine
138 138 def db_version(url, repository, **opts):
139 139 """%prog db_version URL REPOSITORY_PATH
140 140
141 141 Show the current version of the repository with the given
142 142 connection string, under version control of the specified
143 143 repository.
144 144
145 145 The url should be any valid SQLAlchemy connection string.
146 146 """
147 147 engine = opts.pop('engine')
148 148 schema = ControlledSchema(engine, repository)
149 149 return schema.version
150 150
151 151
152 152 def source(version, dest=None, repository=None, **opts):
153 153 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
154 154
155 155 Display the Python code for a particular version in this
156 156 repository. Save it to the file at DESTINATION or, if omitted,
157 157 send to stdout.
158 158 """
159 159 if repository is None:
160 160 raise exceptions.UsageError("A repository must be specified")
161 161 repo = Repository(repository)
162 162 ret = repo.version(version).script().source()
163 163 if dest is not None:
164 164 dest = open(dest, 'w')
165 165 dest.write(ret)
166 166 dest.close()
167 167 ret = None
168 168 return ret
169 169
170 170
171 171 def upgrade(url, repository, version=None, **opts):
172 172 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
173 173
174 174 Upgrade a database to a later version.
175 175
176 176 This runs the upgrade() function defined in your change scripts.
177 177
178 178 By default, the database is updated to the latest available
179 179 version. You may specify a version instead, if you wish.
180 180
181 181 You may preview the Python or SQL code to be executed, rather than
182 182 actually executing it, using the appropriate 'preview' option.
183 183 """
184 184 err = "Cannot upgrade a database of version %s to version %s. "\
185 185 "Try 'downgrade' instead."
186 186 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
187 187
188 188
189 189 def downgrade(url, repository, version, **opts):
190 190 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
191 191
192 192 Downgrade a database to an earlier version.
193 193
194 194 This is the reverse of upgrade; this runs the downgrade() function
195 195 defined in your change scripts.
196 196
197 197 You may preview the Python or SQL code to be executed, rather than
198 198 actually executing it, using the appropriate 'preview' option.
199 199 """
200 200 err = "Cannot downgrade a database of version %s to version %s. "\
201 201 "Try 'upgrade' instead."
202 202 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
203 203
204 204 @with_engine
205 205 def test(url, repository, **opts):
206 206 """%prog test URL REPOSITORY_PATH [VERSION]
207 207
208 208 Performs the upgrade and downgrade option on the given
209 209 database. This is not a real test and may leave the database in a
210 210 bad state. You should therefore better run the test on a copy of
211 211 your database.
212 212 """
213 213 engine = opts.pop('engine')
214 214 repos = Repository(repository)
215 215 script = repos.version(None).script()
216 216
217 217 # Upgrade
218 218 log.info("Upgrading...")
219 219 script.run(engine, 1)
220 220 log.info("done")
221 221
222 222 log.info("Downgrading...")
223 223 script.run(engine, -1)
224 224 log.info("done")
225 225 log.info("Success")
226 226
227 227
228 228 @with_engine
229 229 def version_control(url, repository, version=None, **opts):
230 230 """%prog version_control URL REPOSITORY_PATH [VERSION]
231 231
232 232 Mark a database as under this repository's version control.
233 233
234 234 Once a database is under version control, schema changes should
235 235 only be done via change scripts in this repository.
236 236
237 237 This creates the table version_table in the database.
238 238
239 239 The url should be any valid SQLAlchemy connection string.
240 240
241 241 By default, the database begins at version 0 and is assumed to be
242 242 empty. If the database is not empty, you may specify a version at
243 243 which to begin instead. No attempt is made to verify this
244 244 version's correctness - the database schema is expected to be
245 245 identical to what it would be if the database were created from
246 246 scratch.
247 247 """
248 248 engine = opts.pop('engine')
249 249 ControlledSchema.create(engine, repository, version)
250 250
251 251
252 252 @with_engine
253 253 def drop_version_control(url, repository, **opts):
254 254 """%prog drop_version_control URL REPOSITORY_PATH
255 255
256 256 Removes version control from a database.
257 257 """
258 258 engine = opts.pop('engine')
259 259 schema = ControlledSchema(engine, repository)
260 260 schema.drop()
261 261
262 262
263 263 def manage(file, **opts):
264 264 """%prog manage FILENAME [VARIABLES...]
265 265
266 266 Creates a script that runs Migrate with a set of default values.
267 267
268 268 For example::
269 269
270 270 %prog manage manage.py --repository=/path/to/repository \
271 271 --url=sqlite:///project.db
272 272
273 273 would create the script manage.py. The following two commands
274 274 would then have exactly the same results::
275 275
276 276 python manage.py version
277 277 %prog version --repository=/path/to/repository
278 278 """
279 279 Repository.create_manage_file(file, **opts)
280 280
281 281
282 282 @with_engine
283 283 def compare_model_to_db(url, repository, model, **opts):
284 284 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
285 285
286 286 Compare the current model (assumed to be a module level variable
287 287 of type sqlalchemy.MetaData) against the current database.
288 288
289 289 NOTE: This is EXPERIMENTAL.
290 290 """ # TODO: get rid of EXPERIMENTAL label
291 291 engine = opts.pop('engine')
292 292 return ControlledSchema.compare_model_to_db(engine, model, repository)
293 293
294 294
295 295 @with_engine
296 296 def create_model(url, repository, **opts):
297 297 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
298 298
299 299 Dump the current database as a Python model to stdout.
300 300
301 301 NOTE: This is EXPERIMENTAL.
302 302 """ # TODO: get rid of EXPERIMENTAL label
303 303 engine = opts.pop('engine')
304 304 declarative = opts.get('declarative', False)
305 305 return ControlledSchema.create_model(engine, repository, declarative)
306 306
307 307
308 308 @catch_known_errors
309 309 @with_engine
310 310 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
311 311 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
312 312
313 313 Create a script changing the old Python model to the new (current)
314 314 Python model, sending to stdout.
315 315
316 316 NOTE: This is EXPERIMENTAL.
317 317 """ # TODO: get rid of EXPERIMENTAL label
318 318 engine = opts.pop('engine')
319 319 return PythonScript.make_update_script_for_model(
320 320 engine, oldmodel, model, repository, **opts)
321 321
322 322
323 323 @with_engine
324 324 def update_db_from_model(url, repository, model, **opts):
325 325 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
326 326
327 327 Modify the database to match the structure of the current Python
328 328 model. This also sets the db_version number to the latest in the
329 329 repository.
330 330
331 331 NOTE: This is EXPERIMENTAL.
332 332 """ # TODO: get rid of EXPERIMENTAL label
333 333 engine = opts.pop('engine')
334 334 schema = ControlledSchema(engine, repository)
335 335 schema.update_db_from_model(model)
336 336
337 337 @with_engine
338 338 def _migrate(url, repository, version, upgrade, err, **opts):
339 339 engine = opts.pop('engine')
340 340 url = str(engine.url)
341 341 schema = ControlledSchema(engine, repository)
342 342 version = _migrate_version(schema, version, upgrade, err)
343 343
344 344 changeset = schema.changeset(version)
345 345 for ver, change in changeset:
346 346 nextver = ver + changeset.step
347 347 log.info('%s -> %s... ', ver, nextver)
348 348
349 349 if opts.get('preview_sql'):
350 350 if isinstance(change, PythonScript):
351 351 log.info(change.preview_sql(url, changeset.step, **opts))
352 352 elif isinstance(change, SqlScript):
353 353 log.info(change.source())
354 354
355 355 elif opts.get('preview_py'):
356 356 if not isinstance(change, PythonScript):
357 357 raise exceptions.UsageError("Python source can be only displayed"
358 358 " for python migration files")
359 359 source_ver = max(ver, nextver)
360 360 module = schema.repository.version(source_ver).script().module
361 361 funcname = upgrade and "upgrade" or "downgrade"
362 362 func = getattr(module, funcname)
363 363 log.info(inspect.getsource(func))
364 364 else:
365 365 schema.runchange(ver, change, changeset.step)
366 366 log.info('done')
367 367
368 368
369 369 def _migrate_version(schema, version, upgrade, err):
370 370 if version is None:
371 371 return version
372 372 # Version is specified: ensure we're upgrading in the right direction
373 373 # (current version < target version for upgrading; reverse for down)
374 374 version = VerNum(version)
375 375 cur = schema.version
376 376 if upgrade is not None:
377 377 if upgrade:
378 378 direction = cur <= version
379 379 else:
380 380 direction = cur >= version
381 381 if not direction:
382 382 raise exceptions.KnownError(err % (cur, version))
383 383 return version
@@ -1,253 +1,285 b''
1 1 """
2 Code to generate a Python model from a database or differences
3 between a model and database.
2 Code to generate a Python model from a database or differences
3 between a model and database.
4 4
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
7 7 """
8 8
9 9 import sys
10 10 import logging
11 11
12 12 import sqlalchemy
13 13
14 14 from rhodecode.lib.dbmigrate import migrate
15 15 from rhodecode.lib.dbmigrate.migrate import changeset
16 16
17
17 18 log = logging.getLogger(__name__)
18 19 HEADER = """
19 20 ## File autogenerated by genmodel.py
20 21
21 22 from sqlalchemy import *
22 23 meta = MetaData()
23 24 """
24 25
25 26 DECLARATIVE_HEADER = """
26 27 ## File autogenerated by genmodel.py
27 28
28 29 from sqlalchemy import *
29 30 from sqlalchemy.ext import declarative
30 31
31 32 Base = declarative.declarative_base()
32 33 """
33 34
34 35
35 36 class ModelGenerator(object):
37 """Various transformations from an A, B diff.
38
39 In the implementation, A tends to be called the model and B
40 the database (although this is not true of all diffs).
41 The diff is directionless, but transformations apply the diff
42 in a particular direction, described in the method name.
43 """
36 44
37 45 def __init__(self, diff, engine, declarative=False):
38 46 self.diff = diff
39 47 self.engine = engine
40 48 self.declarative = declarative
41 49
42 50 def column_repr(self, col):
43 51 kwarg = []
44 52 if col.key != col.name:
45 53 kwarg.append('key')
46 54 if col.primary_key:
47 55 col.primary_key = True # otherwise it dumps it as 1
48 56 kwarg.append('primary_key')
49 57 if not col.nullable:
50 58 kwarg.append('nullable')
51 59 if col.onupdate:
52 60 kwarg.append('onupdate')
53 61 if col.default:
54 62 if col.primary_key:
55 63 # I found that PostgreSQL automatically creates a
56 64 # default value for the sequence, but let's not show
57 65 # that.
58 66 pass
59 67 else:
60 68 kwarg.append('default')
61 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
69 args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
62 70
63 71 # crs: not sure if this is good idea, but it gets rid of extra
64 72 # u''
65 73 name = col.name.encode('utf8')
66 74
67 75 type_ = col.type
68 76 for cls in col.type.__class__.__mro__:
69 77 if cls.__module__ == 'sqlalchemy.types' and \
70 78 not cls.__name__.isupper():
71 79 if cls is not type_.__class__:
72 80 type_ = cls()
73 81 break
74 82
83 type_repr = repr(type_)
84 if type_repr.endswith('()'):
85 type_repr = type_repr[:-2]
86
87 constraints = [repr(cn) for cn in col.constraints]
88
75 89 data = {
76 90 'name': name,
77 'type': type_,
78 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
79 'args': ks and ks or ''}
91 'commonStuff': ', '.join([type_repr] + constraints + args),
92 }
80 93
81 if data['constraints']:
82 if data['args']:
83 data['args'] = ',' + data['args']
84
85 if data['constraints'] or data['args']:
86 data['maybeComma'] = ','
94 if self.declarative:
95 return """%(name)s = Column(%(commonStuff)s)""" % data
87 96 else:
88 data['maybeComma'] = ''
97 return """Column(%(name)r, %(commonStuff)s)""" % data
89 98
90 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
91 commonStuff = commonStuff.strip()
92 data['commonStuff'] = commonStuff
93 if self.declarative:
94 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
95 else:
96 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
97
98 def getTableDefn(self, table):
99 def _getTableDefn(self, table, metaName='meta'):
99 100 out = []
100 101 tableName = table.name
101 102 if self.declarative:
102 103 out.append("class %(table)s(Base):" % {'table': tableName})
103 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
104 out.append(" __tablename__ = '%(table)s'\n" %
105 {'table': tableName})
104 106 for col in table.columns:
105 out.append(" %s" % self.column_repr(col))
107 out.append(" %s" % self.column_repr(col))
108 out.append('\n')
106 109 else:
107 out.append("%(table)s = Table('%(table)s', meta," % \
108 {'table': tableName})
110 out.append("%(table)s = Table('%(table)s', %(meta)s," %
111 {'table': tableName, 'meta': metaName})
109 112 for col in table.columns:
110 out.append(" %s," % self.column_repr(col))
111 out.append(")")
113 out.append(" %s," % self.column_repr(col))
114 out.append(")\n")
112 115 return out
113 116
114 117 def _get_tables(self,missingA=False,missingB=False,modified=False):
115 118 to_process = []
116 119 for bool_,names,metadata in (
117 120 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
118 121 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
119 122 (modified,self.diff.tables_different,self.diff.metadataA),
120 123 ):
121 124 if bool_:
122 125 for name in names:
123 126 yield metadata.tables.get(name)
124 127
125 def toPython(self):
126 """Assume database is current and model is empty."""
128 def genBDefinition(self):
129 """Generates the source code for a definition of B.
130
131 Assumes a diff where A is empty.
132
133 Was: toPython. Assume database (B) is current and model (A) is empty.
134 """
135
127 136 out = []
128 137 if self.declarative:
129 138 out.append(DECLARATIVE_HEADER)
130 139 else:
131 140 out.append(HEADER)
132 141 out.append("")
133 142 for table in self._get_tables(missingA=True):
134 out.extend(self.getTableDefn(table))
135 out.append("")
143 out.extend(self._getTableDefn(table))
136 144 return '\n'.join(out)
137 145
138 def toUpgradeDowngradePython(self, indent=' '):
139 ''' Assume model is most current and database is out-of-date. '''
140 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
141 'meta = MetaData()']
142 for table in self._get_tables(
143 missingA=True,missingB=True,modified=True
144 ):
145 decls.extend(self.getTableDefn(table))
146 def genB2AMigration(self, indent=' '):
147 '''Generate a migration from B to A.
148
149 Was: toUpgradeDowngradePython
150 Assume model (A) is most current and database (B) is out-of-date.
151 '''
152
153 decls = ['from migrate.changeset import schema',
154 'pre_meta = MetaData()',
155 'post_meta = MetaData()',
156 ]
157 upgradeCommands = ['pre_meta.bind = migrate_engine',
158 'post_meta.bind = migrate_engine']
159 downgradeCommands = list(upgradeCommands)
160
161 for tn in self.diff.tables_missing_from_A:
162 pre_table = self.diff.metadataB.tables[tn]
163 decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
164 upgradeCommands.append(
165 "pre_meta.tables[%(table)r].drop()" % {'table': tn})
166 downgradeCommands.append(
167 "pre_meta.tables[%(table)r].create()" % {'table': tn})
146 168
147 upgradeCommands, downgradeCommands = [], []
148 for tableName in self.diff.tables_missing_from_A:
149 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
150 downgradeCommands.append("%(table)s.create()" % \
151 {'table': tableName})
152 for tableName in self.diff.tables_missing_from_B:
153 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
154 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
169 for tn in self.diff.tables_missing_from_B:
170 post_table = self.diff.metadataA.tables[tn]
171 decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
172 upgradeCommands.append(
173 "post_meta.tables[%(table)r].create()" % {'table': tn})
174 downgradeCommands.append(
175 "post_meta.tables[%(table)r].drop()" % {'table': tn})
155 176
156 for tableName in self.diff.tables_different:
157 dbTable = self.diff.metadataB.tables[tableName]
158 missingInDatabase, missingInModel, diffDecl = \
159 self.diff.colDiffs[tableName]
160 for col in missingInDatabase:
161 upgradeCommands.append('%s.columns[%r].create()' % (
162 modelTable, col.name))
163 downgradeCommands.append('%s.columns[%r].drop()' % (
164 modelTable, col.name))
165 for col in missingInModel:
166 upgradeCommands.append('%s.columns[%r].drop()' % (
167 modelTable, col.name))
168 downgradeCommands.append('%s.columns[%r].create()' % (
169 modelTable, col.name))
170 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
177 for (tn, td) in self.diff.tables_different.iteritems():
178 if td.columns_missing_from_A or td.columns_different:
179 pre_table = self.diff.metadataB.tables[tn]
180 decls.extend(self._getTableDefn(
181 pre_table, metaName='pre_meta'))
182 if td.columns_missing_from_B or td.columns_different:
183 post_table = self.diff.metadataA.tables[tn]
184 decls.extend(self._getTableDefn(
185 post_table, metaName='post_meta'))
186
187 for col in td.columns_missing_from_A:
188 upgradeCommands.append(
189 'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
190 downgradeCommands.append(
191 'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
192 for col in td.columns_missing_from_B:
193 upgradeCommands.append(
194 'post_meta.tables[%r].columns[%r].create()' % (tn, col))
195 downgradeCommands.append(
196 'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
197 for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
171 198 upgradeCommands.append(
172 199 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
173 modelTable, modelCol.name, databaseCol.name))
200 tn, modelCol.name, databaseCol.name))
174 201 downgradeCommands.append(
175 202 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
176 modelTable, modelCol.name, databaseCol.name))
177 pre_command = ' meta.bind = migrate_engine'
203 tn, modelCol.name, databaseCol.name))
178 204
179 205 return (
180 206 '\n'.join(decls),
181 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
207 '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
208 '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
183 209
184 210 def _db_can_handle_this_change(self,td):
211 """Check if the database can handle going from B to A."""
212
185 213 if (td.columns_missing_from_B
186 214 and not td.columns_missing_from_A
187 215 and not td.columns_different):
188 # Even sqlite can handle this.
216 # Even sqlite can handle column additions.
189 217 return True
190 218 else:
191 219 return not self.engine.url.drivername.startswith('sqlite')
192 220
193 def applyModel(self):
194 """Apply model to current database."""
221 def runB2A(self):
222 """Goes from B to A.
223
224 Was: applyModel. Apply model (A) to current database (B).
225 """
195 226
196 227 meta = sqlalchemy.MetaData(self.engine)
197 228
198 229 for table in self._get_tables(missingA=True):
199 230 table = table.tometadata(meta)
200 231 table.drop()
201 232 for table in self._get_tables(missingB=True):
202 233 table = table.tometadata(meta)
203 234 table.create()
204 235 for modelTable in self._get_tables(modified=True):
205 236 tableName = modelTable.name
206 237 modelTable = modelTable.tometadata(meta)
207 238 dbTable = self.diff.metadataB.tables[tableName]
208 239
209 240 td = self.diff.tables_different[tableName]
210 241
211 242 if self._db_can_handle_this_change(td):
212 243
213 244 for col in td.columns_missing_from_B:
214 245 modelTable.columns[col].create()
215 246 for col in td.columns_missing_from_A:
216 247 dbTable.columns[col].drop()
217 248 # XXX handle column changes here.
218 249 else:
219 250 # Sqlite doesn't support drop column, so you have to
220 251 # do more: create temp table, copy data to it, drop
221 252 # old table, create new table, copy data back.
222 253 #
223 254 # I wonder if this is guaranteed to be unique?
224 255 tempName = '_temp_%s' % modelTable.name
225 256
226 257 def getCopyStatement():
227 258 preparer = self.engine.dialect.preparer
228 259 commonCols = []
229 260 for modelCol in modelTable.columns:
230 261 if modelCol.name in dbTable.columns:
231 262 commonCols.append(modelCol.name)
232 263 commonColsStr = ', '.join(commonCols)
233 264 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
234 265 (tableName, commonColsStr, commonColsStr, tempName)
235 266
236 267 # Move the data in one transaction, so that we don't
237 268 # leave the database in a nasty state.
238 269 connection = self.engine.connect()
239 270 trans = connection.begin()
240 271 try:
241 272 connection.execute(
242 273 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
243 274 (tempName, modelTable.name))
244 275 # make sure the drop takes place inside our
245 276 # transaction with the bind parameter
246 277 modelTable.drop(bind=connection)
247 278 modelTable.create(bind=connection)
248 279 connection.execute(getCopyStatement())
249 280 connection.execute('DROP TABLE %s' % tempName)
250 281 trans.commit()
251 282 except:
252 283 trans.rollback()
253 284 raise
285
@@ -1,231 +1,242 b''
1 1 """
2 2 SQLAlchemy migrate repository management.
3 3 """
4 4 import os
5 5 import shutil
6 6 import string
7 7 import logging
8 8
9 9 from pkg_resources import resource_filename
10 10 from tempita import Template as TempitaTemplate
11 11
12 12 from rhodecode.lib.dbmigrate.migrate import exceptions
13 13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
14 14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
15 15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
16 16
17 17
18 18 log = logging.getLogger(__name__)
19 19
20 20 class Changeset(dict):
21 21 """A collection of changes to be applied to a database.
22 22
23 23 Changesets are bound to a repository and manage a set of
24 24 scripts from that repository.
25 25
26 26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 27 """
28 28
29 29 def __init__(self, start, *changes, **k):
30 30 """
31 31 Give a start version; step must be explicitly stated.
32 32 """
33 33 self.step = k.pop('step', 1)
34 34 self.start = version.VerNum(start)
35 35 self.end = self.start
36 36 for change in changes:
37 37 self.add(change)
38 38
39 39 def __iter__(self):
40 40 return iter(self.items())
41 41
42 42 def keys(self):
43 43 """
44 44 In a series of upgrades x -> y, keys are version x. Sorted.
45 45 """
46 46 ret = super(Changeset, self).keys()
47 47 # Reverse order if downgrading
48 48 ret.sort(reverse=(self.step < 1))
49 49 return ret
50 50
51 51 def values(self):
52 52 return [self[k] for k in self.keys()]
53 53
54 54 def items(self):
55 55 return zip(self.keys(), self.values())
56 56
57 57 def add(self, change):
58 58 """Add new change to changeset"""
59 59 key = self.end
60 60 self.end += self.step
61 61 self[key] = change
62 62
63 63 def run(self, *p, **k):
64 64 """Run the changeset scripts"""
65 65 for version, script in self:
66 66 script.run(*p, **k)
67 67
68 68
69 69 class Repository(pathed.Pathed):
70 70 """A project's change script repository"""
71 71
72 72 _config = 'migrate.cfg'
73 73 _versions = 'versions'
74 74
75 75 def __init__(self, path):
76 76 log.debug('Loading repository %s...' % path)
77 77 self.verify(path)
78 78 super(Repository, self).__init__(path)
79 79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 80 self.versions = version.Collection(os.path.join(self.path,
81 81 self._versions))
82 82 log.debug('Repository %s loaded successfully' % path)
83 83 log.debug('Config: %r' % self.config.to_dict())
84 84
85 85 @classmethod
86 86 def verify(cls, path):
87 87 """
88 88 Ensure the target path is a valid repository.
89 89
90 90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 91 """
92 92 # Ensure the existence of required files
93 93 try:
94 94 cls.require_found(path)
95 95 cls.require_found(os.path.join(path, cls._config))
96 96 cls.require_found(os.path.join(path, cls._versions))
97 97 except exceptions.PathNotFoundError, e:
98 98 raise exceptions.InvalidRepositoryError(path)
99 99
100 100 @classmethod
101 101 def prepare_config(cls, tmpl_dir, name, options=None):
102 102 """
103 103 Prepare a project configuration file for a new project.
104 104
105 105 :param tmpl_dir: Path to Repository template
106 106 :param config_file: Name of the config file in Repository template
107 107 :param name: Repository name
108 108 :type tmpl_dir: string
109 109 :type config_file: string
110 110 :type name: string
111 111 :returns: Populated config file
112 112 """
113 113 if options is None:
114 114 options = {}
115 115 options.setdefault('version_table', 'migrate_version')
116 116 options.setdefault('repository_id', name)
117 117 options.setdefault('required_dbs', [])
118 options.setdefault('use_timestamp_numbering', '0')
118 119
119 120 tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
120 121 ret = TempitaTemplate(tmpl).substitute(options)
121 122
122 123 # cleanup
123 124 del options['__template_name__']
124 125
125 126 return ret
126 127
127 128 @classmethod
128 129 def create(cls, path, name, **opts):
129 130 """Create a repository at a specified path"""
130 131 cls.require_notfound(path)
131 132 theme = opts.pop('templates_theme', None)
132 133 t_path = opts.pop('templates_path', None)
133 134
134 135 # Create repository
135 136 tmpl_dir = Template(t_path).get_repository(theme=theme)
136 137 shutil.copytree(tmpl_dir, path)
137 138
138 139 # Edit config defaults
139 140 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
140 141 fd = open(os.path.join(path, cls._config), 'w')
141 142 fd.write(config_text)
142 143 fd.close()
143 144
144 145 opts['repository_name'] = name
145 146
146 147 # Create a management script
147 148 manager = os.path.join(path, 'manage.py')
148 149 Repository.create_manage_file(manager, templates_theme=theme,
149 150 templates_path=t_path, **opts)
150 151
151 152 return cls(path)
152 153
153 154 def create_script(self, description, **k):
154 155 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
156
157 k['use_timestamp_numbering'] = self.use_timestamp_numbering
155 158 self.versions.create_new_python_version(description, **k)
156 159
157 def create_script_sql(self, database, **k):
160 def create_script_sql(self, database, description, **k):
158 161 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
159 self.versions.create_new_sql_version(database, **k)
162 k['use_timestamp_numbering'] = self.use_timestamp_numbering
163 self.versions.create_new_sql_version(database, description, **k)
160 164
161 165 @property
162 166 def latest(self):
163 167 """API to :attr:`migrate.versioning.version.Collection.latest`"""
164 168 return self.versions.latest
165 169
166 170 @property
167 171 def version_table(self):
168 172 """Returns version_table name specified in config"""
169 173 return self.config.get('db_settings', 'version_table')
170 174
171 175 @property
172 176 def id(self):
173 177 """Returns repository id specified in config"""
174 178 return self.config.get('db_settings', 'repository_id')
175 179
180 @property
181 def use_timestamp_numbering(self):
182 """Returns use_timestamp_numbering specified in config"""
183 ts_numbering = self.config.get('db_settings', 'use_timestamp_numbering', raw=True)
184
185 return ts_numbering
186
176 187 def version(self, *p, **k):
177 188 """API to :attr:`migrate.versioning.version.Collection.version`"""
178 189 return self.versions.version(*p, **k)
179 190
180 191 @classmethod
181 192 def clear(cls):
182 193 # TODO: deletes repo
183 194 super(Repository, cls).clear()
184 195 version.Collection.clear()
185 196
186 197 def changeset(self, database, start, end=None):
187 198 """Create a changeset to migrate this database from ver. start to end/latest.
188 199
189 200 :param database: name of database to generate changeset
190 201 :param start: version to start at
191 202 :param end: version to end at (latest if None given)
192 203 :type database: string
193 204 :type start: int
194 205 :type end: int
195 206 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
196 207 """
197 208 start = version.VerNum(start)
198 209
199 210 if end is None:
200 211 end = self.latest
201 212 else:
202 213 end = version.VerNum(end)
203 214
204 215 if start <= end:
205 216 step = 1
206 217 range_mod = 1
207 218 op = 'upgrade'
208 219 else:
209 220 step = -1
210 221 range_mod = 0
211 222 op = 'downgrade'
212 223
213 224 versions = range(start + range_mod, end + range_mod, step)
214 225 changes = [self.version(v).script(database, op) for v in versions]
215 226 ret = Changeset(start, step=step, *changes)
216 227 return ret
217 228
218 229 @classmethod
219 230 def create_manage_file(cls, file_, **opts):
220 231 """Create a project management script (manage.py)
221 232
222 233 :param file_: Destination file to be written
223 234 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
224 235 """
225 236 mng_file = Template(opts.pop('templates_path', None))\
226 237 .get_manage(theme=opts.pop('templates_theme', None))
227 238
228 239 tmpl = open(mng_file).read()
229 240 fd = open(file_, 'w')
230 241 fd.write(TempitaTemplate(tmpl).substitute(opts))
231 242 fd.close()
@@ -1,213 +1,220 b''
1 1 """
2 2 Database schema version management.
3 3 """
4 4 import sys
5 5 import logging
6 6
7 7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 8 create_engine)
9 9 from sqlalchemy.sql import and_
10 10 from sqlalchemy import exceptions as sa_exceptions
11 11 from sqlalchemy.sql import bindparam
12 12
13 13 from rhodecode.lib.dbmigrate.migrate import exceptions
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_07
14 15 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
15 16 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
16 17 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
17 18 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
18 19
19 20
20 21 log = logging.getLogger(__name__)
21 22
22 23 class ControlledSchema(object):
23 24 """A database under version control"""
24 25
25 26 def __init__(self, engine, repository):
26 27 if isinstance(repository, basestring):
27 28 repository = Repository(repository)
28 29 self.engine = engine
29 30 self.repository = repository
30 31 self.meta = MetaData(engine)
31 32 self.load()
32 33
33 34 def __eq__(self, other):
34 35 """Compare two schemas by repositories and versions"""
35 36 return (self.repository is other.repository \
36 37 and self.version == other.version)
37 38
38 39 def load(self):
39 40 """Load controlled schema version info from DB"""
40 41 tname = self.repository.version_table
41 42 try:
42 43 if not hasattr(self, 'table') or self.table is None:
43 44 self.table = Table(tname, self.meta, autoload=True)
44 45
45 46 result = self.engine.execute(self.table.select(
46 47 self.table.c.repository_id == str(self.repository.id)))
47 48
48 49 data = list(result)[0]
49 50 except:
50 51 cls, exc, tb = sys.exc_info()
51 52 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
52 53
53 54 self.version = data['version']
54 55 return data
55 56
56 57 def drop(self):
57 58 """
58 59 Remove version control from a database.
59 60 """
60 try:
61 self.table.drop()
62 except (sa_exceptions.SQLError):
63 raise exceptions.DatabaseNotControlledError(str(self.table))
61 if SQLA_07:
62 try:
63 self.table.drop()
64 except sa_exceptions.DatabaseError:
65 raise exceptions.DatabaseNotControlledError(str(self.table))
66 else:
67 try:
68 self.table.drop()
69 except (sa_exceptions.SQLError):
70 raise exceptions.DatabaseNotControlledError(str(self.table))
64 71
65 72 def changeset(self, version=None):
66 73 """API to Changeset creation.
67 74
68 75 Uses self.version for start version and engine.name
69 76 to get database name.
70 77 """
71 78 database = self.engine.name
72 79 start_ver = self.version
73 80 changeset = self.repository.changeset(database, start_ver, version)
74 81 return changeset
75 82
76 83 def runchange(self, ver, change, step):
77 84 startver = ver
78 85 endver = ver + step
79 86 # Current database version must be correct! Don't run if corrupt!
80 87 if self.version != startver:
81 88 raise exceptions.InvalidVersionError("%s is not %s" % \
82 89 (self.version, startver))
83 90 # Run the change
84 91 change.run(self.engine, step)
85 92
86 93 # Update/refresh database version
87 94 self.update_repository_table(startver, endver)
88 95 self.load()
89 96
90 97 def update_repository_table(self, startver, endver):
91 98 """Update version_table with new information"""
92 99 update = self.table.update(and_(self.table.c.version == int(startver),
93 100 self.table.c.repository_id == str(self.repository.id)))
94 101 self.engine.execute(update, version=int(endver))
95 102
96 103 def upgrade(self, version=None):
97 104 """
98 105 Upgrade (or downgrade) to a specified version, or latest version.
99 106 """
100 107 changeset = self.changeset(version)
101 108 for ver, change in changeset:
102 109 self.runchange(ver, change, changeset.step)
103 110
104 111 def update_db_from_model(self, model):
105 112 """
106 113 Modify the database to match the structure of the current Python model.
107 114 """
108 115 model = load_model(model)
109 116
110 117 diff = schemadiff.getDiffOfModelAgainstDatabase(
111 118 model, self.engine, excludeTables=[self.repository.version_table]
112 119 )
113 genmodel.ModelGenerator(diff,self.engine).applyModel()
120 genmodel.ModelGenerator(diff,self.engine).runB2A()
114 121
115 122 self.update_repository_table(self.version, int(self.repository.latest))
116 123
117 124 self.load()
118 125
119 126 @classmethod
120 127 def create(cls, engine, repository, version=None):
121 128 """
122 129 Declare a database to be under a repository's version control.
123 130
124 131 :raises: :exc:`DatabaseAlreadyControlledError`
125 132 :returns: :class:`ControlledSchema`
126 133 """
127 134 # Confirm that the version # is valid: positive, integer,
128 135 # exists in repos
129 136 if isinstance(repository, basestring):
130 137 repository = Repository(repository)
131 138 version = cls._validate_version(repository, version)
132 139 table = cls._create_table_version(engine, repository, version)
133 140 # TODO: history table
134 141 # Load repository information and return
135 142 return cls(engine, repository)
136 143
137 144 @classmethod
138 145 def _validate_version(cls, repository, version):
139 146 """
140 147 Ensures this is a valid version number for this repository.
141 148
142 149 :raises: :exc:`InvalidVersionError` if invalid
143 150 :return: valid version number
144 151 """
145 152 if version is None:
146 153 version = 0
147 154 try:
148 155 version = VerNum(version) # raises valueerror
149 156 if version < 0 or version > repository.latest:
150 157 raise ValueError()
151 158 except ValueError:
152 159 raise exceptions.InvalidVersionError(version)
153 160 return version
154 161
155 162 @classmethod
156 163 def _create_table_version(cls, engine, repository, version):
157 164 """
158 165 Creates the versioning table in a database.
159 166
160 167 :raises: :exc:`DatabaseAlreadyControlledError`
161 168 """
162 169 # Create tables
163 170 tname = repository.version_table
164 171 meta = MetaData(engine)
165 172
166 173 table = Table(
167 174 tname, meta,
168 175 Column('repository_id', String(250), primary_key=True),
169 176 Column('repository_path', Text),
170 177 Column('version', Integer), )
171 178
172 179 # there can be multiple repositories/schemas in the same db
173 180 if not table.exists():
174 181 table.create()
175 182
176 183 # test for existing repository_id
177 184 s = table.select(table.c.repository_id == bindparam("repository_id"))
178 185 result = engine.execute(s, repository_id=repository.id)
179 186 if result.fetchone():
180 187 raise exceptions.DatabaseAlreadyControlledError
181 188
182 189 # Insert data
183 190 engine.execute(table.insert().values(
184 191 repository_id=repository.id,
185 192 repository_path=repository.path,
186 193 version=int(version)))
187 194 return table
188 195
189 196 @classmethod
190 197 def compare_model_to_db(cls, engine, model, repository):
191 198 """
192 199 Compare the current model against the current database.
193 200 """
194 201 if isinstance(repository, basestring):
195 202 repository = Repository(repository)
196 203 model = load_model(model)
197 204
198 205 diff = schemadiff.getDiffOfModelAgainstDatabase(
199 206 model, engine, excludeTables=[repository.version_table])
200 207 return diff
201 208
202 209 @classmethod
203 210 def create_model(cls, engine, repository, declarative=False):
204 211 """
205 212 Dump the current database as a Python model.
206 213 """
207 214 if isinstance(repository, basestring):
208 215 repository = Repository(repository)
209 216
210 217 diff = schemadiff.getDiffOfModelAgainstDatabase(
211 218 MetaData(), engine, excludeTables=[repository.version_table]
212 219 )
213 return genmodel.ModelGenerator(diff, engine, declarative).toPython()
220 return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()
@@ -1,160 +1,160 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import shutil
5 5 import warnings
6 6 import logging
7 7 import inspect
8 8 from StringIO import StringIO
9 9
10 10 from rhodecode.lib.dbmigrate import migrate
11 11 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
12 12 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
13 13 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
14 14 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
15 15 from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
16 16 from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
17 17
18 18 log = logging.getLogger(__name__)
19 19 __all__ = ['PythonScript']
20 20
21 21
22 22 class PythonScript(base.BaseScript):
23 23 """Base for Python scripts"""
24 24
25 25 @classmethod
26 26 def create(cls, path, **opts):
27 27 """Create an empty migration script at specified path
28 28
29 29 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
30 30 cls.require_notfound(path)
31 31
32 32 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
33 33 shutil.copy(src, path)
34 34
35 35 return cls(path)
36 36
37 37 @classmethod
38 38 def make_update_script_for_model(cls, engine, oldmodel,
39 39 model, repository, **opts):
40 40 """Create a migration script based on difference between two SA models.
41 41
42 42 :param repository: path to migrate repository
43 43 :param oldmodel: dotted.module.name:SAClass or SAClass object
44 44 :param model: dotted.module.name:SAClass or SAClass object
45 45 :param engine: SQLAlchemy engine
46 46 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
47 47 :type oldmodel: string or Class
48 48 :type model: string or Class
49 49 :type engine: Engine instance
50 50 :returns: Upgrade / Downgrade script
51 51 :rtype: string
52 52 """
53 53
54 54 if isinstance(repository, basestring):
55 55 # oh dear, an import cycle!
56 56 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
57 57 repository = Repository(repository)
58 58
59 59 oldmodel = load_model(oldmodel)
60 60 model = load_model(model)
61 61
62 62 # Compute differences.
63 63 diff = schemadiff.getDiffOfModelAgainstModel(
64 model,
64 65 oldmodel,
65 model,
66 66 excludeTables=[repository.version_table])
67 67 # TODO: diff can be False (there is no difference?)
68 68 decls, upgradeCommands, downgradeCommands = \
69 genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
69 genmodel.ModelGenerator(diff,engine).genB2AMigration()
70 70
71 71 # Store differences into file.
72 72 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
73 73 f = open(src)
74 74 contents = f.read()
75 75 f.close()
76 76
77 77 # generate source
78 78 search = 'def upgrade(migrate_engine):'
79 79 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
80 80 if upgradeCommands:
81 81 contents = contents.replace(' pass', upgradeCommands, 1)
82 82 if downgradeCommands:
83 83 contents = contents.replace(' pass', downgradeCommands, 1)
84 84 return contents
85 85
86 86 @classmethod
87 87 def verify_module(cls, path):
88 88 """Ensure path is a valid script
89 89
90 90 :param path: Script location
91 91 :type path: string
92 92 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
93 93 :returns: Python module
94 94 """
95 95 # Try to import and get the upgrade() func
96 96 module = import_path(path)
97 97 try:
98 98 assert callable(module.upgrade)
99 99 except Exception, e:
100 100 raise InvalidScriptError(path + ': %s' % str(e))
101 101 return module
102 102
103 103 def preview_sql(self, url, step, **args):
104 104 """Mocks SQLAlchemy Engine to store all executed calls in a string
105 105 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
106 106
107 107 :returns: SQL file
108 108 """
109 109 buf = StringIO()
110 110 args['engine_arg_strategy'] = 'mock'
111 111 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
112 112
113 113 @with_engine
114 114 def go(url, step, **kw):
115 115 engine = kw.pop('engine')
116 116 self.run(engine, step)
117 117 return buf.getvalue()
118 118
119 119 return go(url, step, **args)
120 120
121 121 def run(self, engine, step):
122 122 """Core method of Script file.
123 123 Exectues :func:`update` or :func:`downgrade` functions
124 124
125 125 :param engine: SQLAlchemy Engine
126 126 :param step: Operation to run
127 127 :type engine: string
128 128 :type step: int
129 129 """
130 130 if step > 0:
131 131 op = 'upgrade'
132 132 elif step < 0:
133 133 op = 'downgrade'
134 134 else:
135 135 raise ScriptError("%d is not a valid step" % step)
136 136
137 137 funcname = base.operations[op]
138 138 script_func = self._func(funcname)
139 139
140 140 # check for old way of using engine
141 141 if not inspect.getargspec(script_func)[0]:
142 142 raise TypeError("upgrade/downgrade functions must accept engine"
143 143 " parameter (since version 0.5.4)")
144 144
145 145 script_func(engine)
146 146
147 147 @property
148 148 def module(self):
149 149 """Calls :meth:`migrate.versioning.script.py.verify_module`
150 150 and returns it.
151 151 """
152 152 if not hasattr(self, '_module'):
153 153 self._module = self.verify_module(self.path)
154 154 return self._module
155 155
156 156 def _func(self, funcname):
157 157 if not hasattr(self.module, funcname):
158 158 msg = "Function '%s' is not defined in this script"
159 159 raise ScriptError(msg % funcname)
160 160 return getattr(self.module, funcname)
@@ -1,20 +1,25 b''
1 1 [db_settings]
2 2 # Used to identify which repository this database is versioned under.
3 3 # You can use the name of your project.
4 4 repository_id={{ locals().pop('repository_id') }}
5 5
6 6 # The name of the database table used to track the schema version.
7 7 # This name shouldn't already be used by your project.
8 8 # If this is changed once a database is under version control, you'll need to
9 9 # change the table name in each database too.
10 10 version_table={{ locals().pop('version_table') }}
11 11
12 12 # When committing a change script, Migrate will attempt to generate the
13 13 # sql for all supported databases; normally, if one of them fails - probably
14 14 # because you don't have that database installed - it is ignored and the
15 15 # commit continues, perhaps ending successfully.
16 16 # Databases in this list MUST compile successfully during a commit, or the
17 17 # entire commit will fail. List the databases your application will actually
18 18 # be using to ensure your updates to that database work properly.
19 19 # This must be a list; example: ['postgres','sqlite']
20 20 required_dbs={{ locals().pop('required_dbs') }}
21
22 # When creating new change scripts, Migrate will stamp the new script with
23 # a version number. By default this is latest_version + 1. You can set this
24 # to 'true' to tell Migrate to use the UTC timestamp instead.
25 use_timestamp_numbering='false' No newline at end of file
@@ -1,215 +1,240 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import os
5 5 import re
6 6 import shutil
7 7 import logging
8 8
9 9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
11 from datetime import datetime
11 12
12 13
13 14 log = logging.getLogger(__name__)
14 15
15 16 class VerNum(object):
16 17 """A version number that behaves like a string and int at the same time"""
17 18
18 19 _instances = dict()
19 20
20 21 def __new__(cls, value):
21 22 val = str(value)
22 23 if val not in cls._instances:
23 24 cls._instances[val] = super(VerNum, cls).__new__(cls)
24 25 ret = cls._instances[val]
25 26 return ret
26 27
27 28 def __init__(self,value):
28 29 self.value = str(int(value))
29 30 if self < 0:
30 31 raise ValueError("Version number cannot be negative")
31 32
32 33 def __add__(self, value):
33 34 ret = int(self) + int(value)
34 35 return VerNum(ret)
35 36
36 37 def __sub__(self, value):
37 38 return self + (int(value) * -1)
38 39
39 40 def __cmp__(self, value):
40 41 return int(self) - int(value)
41 42
42 43 def __repr__(self):
43 44 return "<VerNum(%s)>" % self.value
44 45
45 46 def __str__(self):
46 47 return str(self.value)
47 48
48 49 def __int__(self):
49 50 return int(self.value)
50 51
51 52
52 53 class Collection(pathed.Pathed):
53 54 """A collection of versioning scripts in a repository"""
54 55
55 56 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
56 57
57 58 def __init__(self, path):
58 59 """Collect current version scripts in repository
59 60 and store them in self.versions
60 61 """
61 62 super(Collection, self).__init__(path)
62
63
63 64 # Create temporary list of files, allowing skipped version numbers.
64 65 files = os.listdir(path)
65 66 if '1' in files:
66 67 # deprecation
67 68 raise Exception('It looks like you have a repository in the old '
68 69 'format (with directories for each version). '
69 70 'Please convert repository before proceeding.')
70 71
71 72 tempVersions = dict()
72 73 for filename in files:
73 74 match = self.FILENAME_WITH_VERSION.match(filename)
74 75 if match:
75 76 num = int(match.group(1))
76 77 tempVersions.setdefault(num, []).append(filename)
77 78 else:
78 79 pass # Must be a helper file or something, let's ignore it.
79 80
80 81 # Create the versions member where the keys
81 82 # are VerNum's and the values are Version's.
82 83 self.versions = dict()
83 84 for num, files in tempVersions.items():
84 85 self.versions[VerNum(num)] = Version(num, path, files)
85 86
86 87 @property
87 88 def latest(self):
88 89 """:returns: Latest version in Collection"""
89 90 return max([VerNum(0)] + self.versions.keys())
90 91
92 def _next_ver_num(self, use_timestamp_numbering):
93 print use_timestamp_numbering
94 if use_timestamp_numbering == True:
95 print "Creating new timestamp version!"
96 return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
97 else:
98 return self.latest + 1
99
91 100 def create_new_python_version(self, description, **k):
92 101 """Create Python files for new version"""
93 ver = self.latest + 1
102 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
94 103 extra = str_to_filename(description)
95 104
96 105 if extra:
97 106 if extra == '_':
98 107 extra = ''
99 108 elif not extra.startswith('_'):
100 109 extra = '_%s' % extra
101 110
102 111 filename = '%03d%s.py' % (ver, extra)
103 112 filepath = self._version_path(filename)
104 113
105 114 script.PythonScript.create(filepath, **k)
106 115 self.versions[ver] = Version(ver, self.path, [filename])
107
108 def create_new_sql_version(self, database, **k):
116
117 def create_new_sql_version(self, database, description, **k):
109 118 """Create SQL files for new version"""
110 ver = self.latest + 1
119 ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
111 120 self.versions[ver] = Version(ver, self.path, [])
112 121
122 extra = str_to_filename(description)
123
124 if extra:
125 if extra == '_':
126 extra = ''
127 elif not extra.startswith('_'):
128 extra = '_%s' % extra
129
113 130 # Create new files.
114 131 for op in ('upgrade', 'downgrade'):
115 filename = '%03d_%s_%s.sql' % (ver, database, op)
132 filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
116 133 filepath = self._version_path(filename)
117 134 script.SqlScript.create(filepath, **k)
118 135 self.versions[ver].add_script(filepath)
119
136
120 137 def version(self, vernum=None):
121 138 """Returns latest Version if vernum is not given.
122 139 Otherwise, returns wanted version"""
123 140 if vernum is None:
124 141 vernum = self.latest
125 142 return self.versions[VerNum(vernum)]
126 143
127 144 @classmethod
128 145 def clear(cls):
129 146 super(Collection, cls).clear()
130 147
131 148 def _version_path(self, ver):
132 149 """Returns path of file in versions repository"""
133 150 return os.path.join(self.path, str(ver))
134 151
135 152
136 153 class Version(object):
137 154 """A single version in a collection
138 :param vernum: Version Number
155 :param vernum: Version Number
139 156 :param path: Path to script files
140 157 :param filelist: List of scripts
141 158 :type vernum: int, VerNum
142 159 :type path: string
143 160 :type filelist: list
144 161 """
145 162
146 163 def __init__(self, vernum, path, filelist):
147 164 self.version = VerNum(vernum)
148 165
149 166 # Collect scripts in this folder
150 167 self.sql = dict()
151 168 self.python = None
152 169
153 170 for script in filelist:
154 171 self.add_script(os.path.join(path, script))
155
172
156 173 def script(self, database=None, operation=None):
157 174 """Returns SQL or Python Script"""
158 175 for db in (database, 'default'):
159 176 # Try to return a .sql script first
160 177 try:
161 178 return self.sql[db][operation]
162 179 except KeyError:
163 180 continue # No .sql script exists
164 181
165 182 # TODO: maybe add force Python parameter?
166 183 ret = self.python
167 184
168 185 assert ret is not None, \
169 186 "There is no script for %d version" % self.version
170 187 return ret
171 188
172 189 def add_script(self, path):
173 190 """Add script to Collection/Version"""
174 191 if path.endswith(Extensions.py):
175 192 self._add_script_py(path)
176 193 elif path.endswith(Extensions.sql):
177 194 self._add_script_sql(path)
178 195
179 SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
196 SQL_FILENAME = re.compile(r'^.*\.sql')
180 197
181 198 def _add_script_sql(self, path):
182 199 basename = os.path.basename(path)
183 200 match = self.SQL_FILENAME.match(basename)
184
201
185 202 if match:
186 version, dbms, op = match.group(1), match.group(2), match.group(3)
203 basename = basename.replace('.sql', '')
204 parts = basename.split('_')
205 if len(parts) < 3:
206 raise exceptions.ScriptError(
207 "Invalid SQL script name %s " % basename + \
208 "(needs to be ###_description_database_operation.sql)")
209 version = parts[0]
210 op = parts[-1]
211 dbms = parts[-2]
187 212 else:
188 213 raise exceptions.ScriptError(
189 214 "Invalid SQL script name %s " % basename + \
190 "(needs to be ###_database_operation.sql)")
215 "(needs to be ###_description_database_operation.sql)")
191 216
192 217 # File the script into a dictionary
193 218 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
194 219
195 220 def _add_script_py(self, path):
196 221 if self.python is not None:
197 222 raise exceptions.ScriptError('You can only have one Python script '
198 223 'per version, but you have: %s and %s' % (self.python, path))
199 224 self.python = script.PythonScript(path)
200 225
201 226
202 227 class Extensions:
203 228 """A namespace for file extensions"""
204 229 py = 'py'
205 230 sql = 'sql'
206 231
207 232 def str_to_filename(s):
208 233 """Replaces spaces, (double and single) quotes
209 234 and double underscores to underscores
210 235 """
211 236
212 237 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
213 238 while '__' in s:
214 239 s = s.replace('__', '_')
215 240 return s
@@ -1,102 +1,116 b''
1 1 import logging
2 2 import datetime
3 3
4 4 from sqlalchemy import *
5 5 from sqlalchemy.exc import DatabaseError
6 6 from sqlalchemy.orm import relation, backref, class_mapper
7 7 from sqlalchemy.orm.session import Session
8 8
9 9 from rhodecode.lib.dbmigrate.migrate import *
10 10 from rhodecode.lib.dbmigrate.migrate.changeset import *
11 11
12 12 from rhodecode.model.meta import Base
13 13
14 14 log = logging.getLogger(__name__)
15 15
16 16 def upgrade(migrate_engine):
17 17 """ Upgrade operations go here.
18 18 Don't create your own engine; bind migrate_engine to your metadata
19 19 """
20 20
21 21 #==========================================================================
22 22 # Add table `groups``
23 23 #==========================================================================
24 24 from rhodecode.model.db import Group
25 25 Group().__table__.create()
26 26
27 27 #==========================================================================
28 28 # Add table `group_to_perm`
29 29 #==========================================================================
30 30 from rhodecode.model.db import GroupToPerm
31 31 GroupToPerm().__table__.create()
32 32
33 33 #==========================================================================
34 34 # Add table `users_groups`
35 35 #==========================================================================
36 36 from rhodecode.model.db import UsersGroup
37 37 UsersGroup().__table__.create()
38 38
39 39 #==========================================================================
40 40 # Add table `users_groups_members`
41 41 #==========================================================================
42 42 from rhodecode.model.db import UsersGroupMember
43 43 UsersGroupMember().__table__.create()
44 44
45 45 #==========================================================================
46 46 # Add table `users_group_repo_to_perm`
47 47 #==========================================================================
48 48 from rhodecode.model.db import UsersGroupRepoToPerm
49 49 UsersGroupRepoToPerm().__table__.create()
50 50
51 51 #==========================================================================
52 52 # Add table `users_group_to_perm`
53 53 #==========================================================================
54 54 from rhodecode.model.db import UsersGroupToPerm
55 55 UsersGroupToPerm().__table__.create()
56 56
57 57 #==========================================================================
58 58 # Upgrade of `users` table
59 59 #==========================================================================
60 60 from rhodecode.model.db import User
61 61
62 62 #add column
63 63 ldap_dn = Column("ldap_dn", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
64 64 ldap_dn.create(User().__table__)
65 65
66 66 api_key = Column("api_key", String(length=255, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
67 67 api_key.create(User().__table__)
68 68
69 69 #remove old column
70 70 is_ldap = Column("is_ldap", Boolean(), nullable=False, unique=None, default=False)
71 71 is_ldap.drop(User().__table__)
72 72
73 73
74 74 #==========================================================================
75 75 # Upgrade of `repositories` table
76 76 #==========================================================================
77 77 from rhodecode.model.db import Repository
78 78
79 79 #ADD downloads column#
80 80 enable_downloads = Column("downloads", Boolean(), nullable=True, unique=None, default=True)
81 81 enable_downloads.create(Repository().__table__)
82 82
83 #ADD column created_on
84 created_on = Column('created_on', DateTime(timezone=False), nullable=True,
85 unique=None, default=datetime.datetime.now)
86 created_on.create(Repository().__table__)
87
83 88 #ADD group_id column#
84 89 group_id = Column("group_id", Integer(), ForeignKey('groups.group_id'),
85 90 nullable=True, unique=False, default=None)
86 91
87 92 group_id.create(Repository().__table__)
88 93
89 94
90 95 #ADD clone_uri column#
91 96
92 97 clone_uri = Column("clone_uri", String(length=255, convert_unicode=False,
93 98 assert_unicode=None),
94 99 nullable=True, unique=False, default=None)
95 100
96 101 clone_uri.create(Repository().__table__)
102
103
104 #==========================================================================
105 # Upgrade of `user_followings` table
106 #==========================================================================
107
108 follows_from = Column('follows_from', DateTime(timezone=False), nullable=True, unique=None, default=datetime.datetime.now)
109 follows_from.create(Repository().__table__)
110
97 111 return
98 112
99 113
100 114 def downgrade(migrate_engine):
101 115 meta = MetaData()
102 116 meta.bind = migrate_engine
General Comments 0
You need to be logged in to leave comments. Login now