"""
   Code to generate a Python model from a database or differences
   between a model and database.

   Some of this is borrowed heavily from the AutoCode project at:
   http://code.google.com/p/sqlautocode/
"""

import sys
import logging

import sqlalchemy

import rhodecode.lib.dbmigrate.migrate
# NOTE(review): presumably imported for its side effects (the column/table
# ``create()``/``drop()`` calls used below); verify before removing.
import rhodecode.lib.dbmigrate.migrate.changeset


log = logging.getLogger(__name__)

# Preamble written at the top of every generated model module.
HEADER = """
## File autogenerated by genmodel.py

from sqlalchemy import *
"""

# Written after HEADER when generating Table()-style (non-declarative) code.
META_DEFINITION = "meta = MetaData()"

# Written after HEADER when generating declarative code.
DECLARATIVE_DEFINITION = """
from sqlalchemy.ext import declarative

Base = declarative.declarative_base()
"""


class ModelGenerator(object):
    """Various transformations from an A, B diff.

    In the implementation, A tends to be called the model and B
    the database (although this is not true of all diffs).
    The diff is directionless, but transformations apply the diff
    in a particular direction, described in the method name.
    """

    def __init__(self, diff, engine, declarative=False):
        """Store the diff to work from.

        :param diff: schema diff exposing ``metadataA``/``metadataB``,
            ``tables_missing_from_A``/``tables_missing_from_B`` and
            ``tables_different``.
        :param engine: SQLAlchemy engine used by :meth:`runB2A` and
            :meth:`_db_can_handle_this_change`.
        :param declarative: emit declarative classes instead of
            ``Table(...)`` definitions.
        """
        self.diff = diff
        self.engine = engine
        self.declarative = declarative

    def column_repr(self, col):
        """Return Python source text that recreates the column *col*.

        In declarative mode the result is an attribute assignment
        (``name = Column(...)``); otherwise it is a ``Column(...)``
        expression suitable for use inside ``Table(...)``.

        Side effect: sets ``col.primary_key = True`` on primary-key
        columns so the flag is rendered as ``True`` rather than ``1``.
        """
        kwarg = []
        if col.key != col.name:
            kwarg.append('key')
        if col.primary_key:
            col.primary_key = True  # otherwise it dumps it as 1
            kwarg.append('primary_key')
        if not col.nullable:
            kwarg.append('nullable')
        if col.onupdate:
            kwarg.append('onupdate')
        if col.default:
            if col.primary_key:
                # I found that PostgreSQL automatically creates a
                # default value for the sequence, but let's not show
                # that.
                pass
            else:
                kwarg.append('default')
        args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]

        # crs: encoding gets rid of the extra u'' prefix that Python 2
        # ``unicode`` names would otherwise show in the repr. On Python 3
        # the name is already ``str``; encoding it would render ``b'...'``
        # (an invalid attribute name in declarative output), so only
        # encode names that are neither native ``str`` nor ``bytes``.
        name = col.name
        if not isinstance(name, (str, bytes)):
            name = name.encode('utf8')

        # Replace a dialect-specific type with the first generic
        # (non-UPPERCASE) class found in ``sqlalchemy.types`` along its MRO,
        # instantiated without arguments when it differs from the original.
        # NOTE(review): newer SQLAlchemy releases define these classes in
        # other modules; confirm the module check against the pinned version.
        type_ = col.type
        for cls in col.type.__class__.__mro__:
            if cls.__module__ == 'sqlalchemy.types' and \
                    not cls.__name__.isupper():
                if cls is not type_.__class__:
                    type_ = cls()
                break

        type_repr = repr(type_)
        # "Integer()" -> "Integer" for terser generated code.
        if type_repr.endswith('()'):
            type_repr = type_repr[:-2]

        constraints = [repr(cn) for cn in col.constraints]

        data = {
            'name': name,
            'commonStuff': ', '.join([type_repr] + constraints + args),
        }

        if self.declarative:
            return """%(name)s = Column(%(commonStuff)s)""" % data
        else:
            return """Column(%(name)r, %(commonStuff)s)""" % data

    def _getTableDefn(self, table, metaName='meta'):
        """Return the source lines defining *table*.

        :param table: the table whose columns are rendered.
        :param metaName: name of the MetaData variable referenced by the
            ``Table(...)`` call (ignored in declarative mode).
        :returns: list of source-code strings.
        """
        out = []
        tableName = table.name
        if self.declarative:
            out.append("class %(table)s(Base):" % {'table': tableName})
            out.append("    __tablename__ = '%(table)s'\n" %
                       {'table': tableName})
            for col in table.columns:
                out.append("    %s" % self.column_repr(col))
            out.append('\n')
        else:
            out.append("%(table)s = Table('%(table)s', %(meta)s," %
                       {'table': tableName, 'meta': metaName})
            for col in table.columns:
                out.append("    %s," % self.column_repr(col))
            out.append(")\n")
        return out

    def _get_tables(self, missingA=False, missingB=False, modified=False):
        """Yield table objects selected from the diff.

        :param missingA: include tables missing from A (taken from B's
            metadata).
        :param missingB: include tables missing from B (taken from A's
            metadata).
        :param modified: include tables that differ (taken from A's
            metadata).

        ``metadata.tables.get`` is used, so a name without a matching
        table yields ``None``.
        """
        for bool_, names, metadata in (
                (missingA, self.diff.tables_missing_from_A,
                 self.diff.metadataB),
                (missingB, self.diff.tables_missing_from_B,
                 self.diff.metadataA),
                (modified, self.diff.tables_different,
                 self.diff.metadataA),
        ):
            if bool_:
                for name in names:
                    yield metadata.tables.get(name)

    def _genModelHeader(self, tables):
        """Return the header lines for a generated model module.

        Adds an import for every column type whose module path contains
        ``"dialects"`` (each type class imported once), followed by the
        MetaData or declarative-base definition.

        :param tables: iterable of tables; it is consumed once.
        """
        out = []
        import_index = []

        out.append(HEADER)

        for table in tables:
            for col in table.columns:
                if "dialects" in col.type.__module__ and \
                        col.type.__class__ not in import_index:
                    out.append("from " + col.type.__module__ +
                               " import " + col.type.__class__.__name__)
                    import_index.append(col.type.__class__)

        out.append("")

        if self.declarative:
            out.append(DECLARATIVE_DEFINITION)
        else:
            out.append(META_DEFINITION)
        out.append("")

        return out

    def genBDefinition(self):
        """Generates the source code for a definition of B.

        Assumes a diff where A is empty.

        Was: toPython. Assume database (B) is current and model (A) is empty.
        """

        out = []
        out.extend(self._genModelHeader(self._get_tables(missingA=True)))
        for table in self._get_tables(missingA=True):
            out.extend(self._getTableDefn(table))
        return '\n'.join(out)

    def genB2AMigration(self, indent='    '):
        """Generate a migration from B to A.

        Was: toUpgradeDowngradePython
        Assume model (A) is most current and database (B) is out-of-date.

        :param indent: prefix added to every upgrade/downgrade line.
        :returns: ``(declarations, upgrade_body, downgrade_body)`` as
            newline-joined source strings.
        """

        decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
                 'pre_meta = MetaData()',
                 'post_meta = MetaData()',
                ]
        upgradeCommands = ['pre_meta.bind = migrate_engine',
                           'post_meta.bind = migrate_engine']
        downgradeCommands = list(upgradeCommands)

        # Tables only in the database (B): dropped on upgrade.
        for tn in self.diff.tables_missing_from_A:
            pre_table = self.diff.metadataB.tables[tn]
            decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
            upgradeCommands.append(
                "pre_meta.tables[%(table)r].drop()" % {'table': tn})
            downgradeCommands.append(
                "pre_meta.tables[%(table)r].create()" % {'table': tn})

        # Tables only in the model (A): created on upgrade.
        for tn in self.diff.tables_missing_from_B:
            post_table = self.diff.metadataA.tables[tn]
            decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
            upgradeCommands.append(
                "post_meta.tables[%(table)r].create()" % {'table': tn})
            downgradeCommands.append(
                "post_meta.tables[%(table)r].drop()" % {'table': tn})

        for (tn, td) in list(self.diff.tables_different.items()):
            if td.columns_missing_from_A or td.columns_different:
                pre_table = self.diff.metadataB.tables[tn]
                decls.extend(self._getTableDefn(
                    pre_table, metaName='pre_meta'))
            if td.columns_missing_from_B or td.columns_different:
                post_table = self.diff.metadataA.tables[tn]
                decls.extend(self._getTableDefn(
                    post_table, metaName='post_meta'))

            for col in td.columns_missing_from_A:
                upgradeCommands.append(
                    'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
                downgradeCommands.append(
                    'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
            for col in td.columns_missing_from_B:
                upgradeCommands.append(
                    'post_meta.tables[%r].columns[%r].create()' % (tn, col))
                downgradeCommands.append(
                    'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
            # NOTE(review): this assumes iterating ``columns_different``
            # yields 4-tuples of (model col, db col, model decl, db decl);
            # if it is a mapping of column names this unpacking fails --
            # confirm against the schemadiff implementation.
            for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
                upgradeCommands.append(
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
                        tn, modelCol.name, databaseCol.name))
                downgradeCommands.append(
                    'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
                        tn, modelCol.name, databaseCol.name))

        return (
            '\n'.join(decls),
            '\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
            '\n'.join('%s%s' % (indent, line) for line in downgradeCommands))

    def _db_can_handle_this_change(self, td):
        """Check if the database can handle going from B to A.

        Pure column additions are always allowed; any other change is
        allowed only when the engine's driver name does not start with
        ``sqlite``.
        """

        if (td.columns_missing_from_B
                and not td.columns_missing_from_A
                and not td.columns_different):
            # Even sqlite can handle column additions.
            return True
        else:
            return not self.engine.url.drivername.startswith('sqlite')

    def runB2A(self):
        """Goes from B to A.

        Was: applyModel. Apply model (A) to current database (B).

        Drops tables missing from A, creates tables missing from B, and
        alters modified tables either in place or, when
        :meth:`_db_can_handle_this_change` says the database cannot,
        by rebuilding the table through a temporary copy inside one
        transaction.
        """

        meta = sqlalchemy.MetaData(self.engine)

        for table in self._get_tables(missingA=True):
            table = table.tometadata(meta)
            table.drop()
        for table in self._get_tables(missingB=True):
            table = table.tometadata(meta)
            table.create()
        for modelTable in self._get_tables(modified=True):
            tableName = modelTable.name
            modelTable = modelTable.tometadata(meta)
            dbTable = self.diff.metadataB.tables[tableName]

            td = self.diff.tables_different[tableName]

            if self._db_can_handle_this_change(td):

                for col in td.columns_missing_from_B:
                    modelTable.columns[col].create()
                for col in td.columns_missing_from_A:
                    dbTable.columns[col].drop()
                # XXX handle column changes here.
            else:
                # Sqlite doesn't support drop column, so you have to
                # do more: create temp table, copy data to it, drop
                # old table, create new table, copy data back.
                #
                # I wonder if this is guaranteed to be unique?
                tempName = '_temp_%s' % modelTable.name

                def getCopyStatement():
                    # Copy back only the columns present in both the
                    # model and the existing database table.
                    commonCols = []
                    for modelCol in modelTable.columns:
                        if modelCol.name in dbTable.columns:
                            commonCols.append(modelCol.name)
                    commonColsStr = ', '.join(commonCols)
                    return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
                        (tableName, commonColsStr, commonColsStr, tempName)

                # Move the data in one transaction, so that we don't
                # leave the database in a nasty state.
                connection = self.engine.connect()
                try:
                    trans = connection.begin()
                    try:
                        connection.execute(
                            'CREATE TEMPORARY TABLE %s as SELECT * from %s' %
                            (tempName, modelTable.name))
                        # make sure the drop takes place inside our
                        # transaction with the bind parameter
                        modelTable.drop(bind=connection)
                        modelTable.create(bind=connection)
                        connection.execute(getCopyStatement())
                        connection.execute('DROP TABLE %s' % tempName)
                        trans.commit()
                    except BaseException:
                        # Roll back on any interruption, then propagate.
                        trans.rollback()
                        raise
                finally:
                    # The original never released this connection.
                    connection.close()