##// END OF EJS Templates
added dbmigrate package, added model changes...
marcink -
r833:9753e090 beta
parent child Browse files
Show More
@@ -0,0 +1,59 b''
1 # -*- coding: utf-8 -*-
2 """
3 rhodecode.lib.dbmigrate.__init__
4 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5
6 Database migration modules
7
8 :created_on: Dec 11, 2010
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
12 """
13 # This program is free software; you can redistribute it and/or
14 # modify it under the terms of the GNU General Public License
15 # as published by the Free Software Foundation; version 2
16 # of the License or (at your opinion) any later version of the license.
17 #
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 # GNU General Public License for more details.
22 #
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 # MA 02110-1301, USA.
27
28 from rhodecode.lib.utils import BasePasterCommand
29 from rhodecode.lib.utils import BasePasterCommand, Command, add_cache
30
31 from sqlalchemy import engine_from_config
32
33 class UpgradeDb(BasePasterCommand):
34 """Command used for paster to upgrade our database to newer version
35 """
36
37 max_args = 1
38 min_args = 1
39
40 usage = "CONFIG_FILE"
41 summary = "Upgrades current db to newer version given configuration file"
42 group_name = "RhodeCode"
43
44 parser = Command.standard_parser(verbose=True)
45
46 def command(self):
47 from pylons import config
48 add_cache(config)
49 engine = engine_from_config(config, 'sqlalchemy.db1.')
50 print engine
51 raise NotImplementedError('Not implemented yet')
52
53
54 def update_parser(self):
55 self.parser.add_option('--sql',
56 action='store_true',
57 dest='just_sql',
58 help="Prints upgrade sql for further investigation",
59 default=False)
@@ -0,0 +1,9 b''
1 """
2 SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
3 database schema version and repository management and
4 :mod:`migrate.changeset` that allows to define database schema changes
5 using Python.
6 """
7
8 from migrate.versioning import *
9 from migrate.changeset import *
@@ -0,0 +1,28 b''
1 """
2 This module extends SQLAlchemy and provides additional DDL [#]_
3 support.
4
5 .. [#] SQL Data Definition Language
6 """
7 import re
8 import warnings
9
10 import sqlalchemy
11 from sqlalchemy import __version__ as _sa_version
12
13 warnings.simplefilter('always', DeprecationWarning)
14
15 _sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
16 SQLA_06 = _sa_version >= (0, 6)
17
18 del re
19 del _sa_version
20
21 from migrate.changeset.schema import *
22 from migrate.changeset.constraint import *
23
24 sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
25 sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
26 sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
27
28 sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )
@@ -0,0 +1,358 b''
1 """
2 Extensions to SQLAlchemy for altering existing tables.
3
4 At the moment, this isn't so much based off of ANSI as much as
5 things that just happen to work with multiple databases.
6 """
7 import StringIO
8
9 import sqlalchemy as sa
10 from sqlalchemy.schema import SchemaVisitor
11 from sqlalchemy.engine.default import DefaultDialect
12 from sqlalchemy.sql import ClauseElement
13 from sqlalchemy.schema import (ForeignKeyConstraint,
14 PrimaryKeyConstraint,
15 CheckConstraint,
16 UniqueConstraint,
17 Index)
18
19 from migrate import exceptions
20 from migrate.changeset import constraint, SQLA_06
21
22 if not SQLA_06:
23 from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
24 else:
25 from sqlalchemy.schema import AddConstraint, DropConstraint
26 from sqlalchemy.sql.compiler import DDLCompiler
27 SchemaGenerator = SchemaDropper = DDLCompiler
28
29
30 class AlterTableVisitor(SchemaVisitor):
31 """Common operations for ``ALTER TABLE`` statements."""
32
33 if SQLA_06:
34 # engine.Compiler looks for .statement
35 # when it spawns off a new compiler
36 statement = ClauseElement()
37
38 def append(self, s):
39 """Append content to the SchemaIterator's query buffer."""
40
41 self.buffer.write(s)
42
43 def execute(self):
44 """Execute the contents of the SchemaIterator's buffer."""
45 try:
46 return self.connection.execute(self.buffer.getvalue())
47 finally:
48 self.buffer.truncate(0)
49
50 def __init__(self, dialect, connection, **kw):
51 self.connection = connection
52 self.buffer = StringIO.StringIO()
53 self.preparer = dialect.identifier_preparer
54 self.dialect = dialect
55
56 def traverse_single(self, elem):
57 ret = super(AlterTableVisitor, self).traverse_single(elem)
58 if ret:
59 # adapt to 0.6 which uses a string-returning
60 # object
61 self.append(" %s" % ret)
62
63 def _to_table(self, param):
64 """Returns the table object for the given param object."""
65 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
66 ret = param.table
67 else:
68 ret = param
69 return ret
70
71 def start_alter_table(self, param):
72 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
73
74 Use the param object to determine the table name and use it
75 for building the SQL statement.
76
77 :param param: object to determine the table from
78 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
79 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
80 or string (table name)
81 """
82 table = self._to_table(param)
83 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
84 return table
85
86
87 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
88 """Extends ansisql generator for column creation (alter table add col)"""
89
90 def visit_column(self, column):
91 """Create a column (table already exists).
92
93 :param column: column object
94 :type column: :class:`sqlalchemy.Column` instance
95 """
96 if column.default is not None:
97 self.traverse_single(column.default)
98
99 table = self.start_alter_table(column)
100 self.append("ADD ")
101 self.append(self.get_column_specification(column))
102
103 for cons in column.constraints:
104 self.traverse_single(cons)
105 self.execute()
106
107 # ALTER TABLE STATEMENTS
108
109 # add indexes and unique constraints
110 if column.index_name:
111 Index(column.index_name,column).create()
112 elif column.unique_name:
113 constraint.UniqueConstraint(column,
114 name=column.unique_name).create()
115
116 # SA bounds FK constraints to table, add manually
117 for fk in column.foreign_keys:
118 self.add_foreignkey(fk.constraint)
119
120 # add primary key constraint if needed
121 if column.primary_key_name:
122 cons = constraint.PrimaryKeyConstraint(column,
123 name=column.primary_key_name)
124 cons.create()
125
126 if SQLA_06:
127 def add_foreignkey(self, fk):
128 self.connection.execute(AddConstraint(fk))
129
130 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
131 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
132 DROP COLUMN``).
133 """
134
135 def visit_column(self, column):
136 """Drop a column from its table.
137
138 :param column: the column object
139 :type column: :class:`sqlalchemy.Column`
140 """
141 table = self.start_alter_table(column)
142 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
143 self.execute()
144
145
146 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
147 """Manages changes to existing schema elements.
148
149 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
150 is in SchemaGenerator.
151
152 All items may be renamed. Columns can also have many of their properties -
153 type, for example - changed.
154
155 Each function is passed a tuple, containing (object, name); where
156 object is a type of object you'd expect for that function
157 (ie. table for visit_table) and name is the object's new
158 name. NONE means the name is unchanged.
159 """
160
161 def visit_table(self, table):
162 """Rename a table. Other ops aren't supported."""
163 self.start_alter_table(table)
164 self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
165 table.quote))
166 self.execute()
167
168 def visit_index(self, index):
169 """Rename an index"""
170 if hasattr(self, '_validate_identifier'):
171 # SA <= 0.6.3
172 self.append("ALTER INDEX %s RENAME TO %s" % (
173 self.preparer.quote(
174 self._validate_identifier(
175 index.name, True), index.quote),
176 self.preparer.quote(
177 self._validate_identifier(
178 index.new_name, True), index.quote)))
179 else:
180 # SA >= 0.6.5
181 self.append("ALTER INDEX %s RENAME TO %s" % (
182 self.preparer.quote(
183 self._index_identifier(
184 index.name), index.quote),
185 self.preparer.quote(
186 self._index_identifier(
187 index.new_name), index.quote)))
188 self.execute()
189
190 def visit_column(self, delta):
191 """Rename/change a column."""
192 # ALTER COLUMN is implemented as several ALTER statements
193 keys = delta.keys()
194 if 'type' in keys:
195 self._run_subvisit(delta, self._visit_column_type)
196 if 'nullable' in keys:
197 self._run_subvisit(delta, self._visit_column_nullable)
198 if 'server_default' in keys:
199 # Skip 'default': only handle server-side defaults, others
200 # are managed by the app, not the db.
201 self._run_subvisit(delta, self._visit_column_default)
202 if 'name' in keys:
203 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
204
205 def _run_subvisit(self, delta, func, start_alter=True):
206 """Runs visit method based on what needs to be changed on column"""
207 table = self._to_table(delta.table)
208 col_name = delta.current_name
209 if start_alter:
210 self.start_alter_column(table, col_name)
211 ret = func(table, delta.result_column, delta)
212 self.execute()
213
214 def start_alter_column(self, table, col_name):
215 """Starts ALTER COLUMN"""
216 self.start_alter_table(table)
217 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
218
219 def _visit_column_nullable(self, table, column, delta):
220 nullable = delta['nullable']
221 if nullable:
222 self.append("DROP NOT NULL")
223 else:
224 self.append("SET NOT NULL")
225
226 def _visit_column_default(self, table, column, delta):
227 default_text = self.get_column_default_string(column)
228 if default_text is not None:
229 self.append("SET DEFAULT %s" % default_text)
230 else:
231 self.append("DROP DEFAULT")
232
233 def _visit_column_type(self, table, column, delta):
234 type_ = delta['type']
235 if SQLA_06:
236 type_text = str(type_.compile(dialect=self.dialect))
237 else:
238 type_text = type_.dialect_impl(self.dialect).get_col_spec()
239 self.append("TYPE %s" % type_text)
240
241 def _visit_column_name(self, table, column, delta):
242 self.start_alter_table(table)
243 col_name = self.preparer.quote(delta.current_name, table.quote)
244 new_name = self.preparer.format_column(delta.result_column)
245 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
246
247
248 class ANSIConstraintCommon(AlterTableVisitor):
249 """
250 Migrate's constraints require a separate creation function from
251 SA's: Migrate's constraints are created independently of a table;
252 SA's are created at the same time as the table.
253 """
254
255 def get_constraint_name(self, cons):
256 """Gets a name for the given constraint.
257
258 If the name is already set it will be used otherwise the
259 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
260 method is used.
261
262 :param cons: constraint object
263 """
264 if cons.name is not None:
265 ret = cons.name
266 else:
267 ret = cons.name = cons.autoname()
268 return self.preparer.quote(ret, cons.quote)
269
270 def visit_migrate_primary_key_constraint(self, *p, **k):
271 self._visit_constraint(*p, **k)
272
273 def visit_migrate_foreign_key_constraint(self, *p, **k):
274 self._visit_constraint(*p, **k)
275
276 def visit_migrate_check_constraint(self, *p, **k):
277 self._visit_constraint(*p, **k)
278
279 def visit_migrate_unique_constraint(self, *p, **k):
280 self._visit_constraint(*p, **k)
281
282 if SQLA_06:
283 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
284 def _visit_constraint(self, constraint):
285 constraint.name = self.get_constraint_name(constraint)
286 self.append(self.process(AddConstraint(constraint)))
287 self.execute()
288
289 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
290 def _visit_constraint(self, constraint):
291 constraint.name = self.get_constraint_name(constraint)
292 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
293 self.execute()
294
295 else:
296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
297
298 def get_constraint_specification(self, cons, **kwargs):
299 """Constaint SQL generators.
300
301 We cannot use SA visitors because they append comma.
302 """
303
304 if isinstance(cons, PrimaryKeyConstraint):
305 if cons.name is not None:
306 self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
307 self.append("PRIMARY KEY ")
308 self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
309 for c in cons))
310 self.define_constraint_deferrability(cons)
311 elif isinstance(cons, ForeignKeyConstraint):
312 self.define_foreign_key(cons)
313 elif isinstance(cons, CheckConstraint):
314 if cons.name is not None:
315 self.append("CONSTRAINT %s " %
316 self.preparer.format_constraint(cons))
317 self.append("CHECK (%s)" % cons.sqltext)
318 self.define_constraint_deferrability(cons)
319 elif isinstance(cons, UniqueConstraint):
320 if cons.name is not None:
321 self.append("CONSTRAINT %s " %
322 self.preparer.format_constraint(cons))
323 self.append("UNIQUE (%s)" % \
324 (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
325 self.define_constraint_deferrability(cons)
326 else:
327 raise exceptions.InvalidConstraintError(cons)
328
329 def _visit_constraint(self, constraint):
330
331 table = self.start_alter_table(constraint)
332 constraint.name = self.get_constraint_name(constraint)
333 self.append("ADD ")
334 self.get_constraint_specification(constraint)
335 self.execute()
336
337
338 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
339
340 def _visit_constraint(self, constraint):
341 self.start_alter_table(constraint)
342 self.append("DROP CONSTRAINT ")
343 constraint.name = self.get_constraint_name(constraint)
344 self.append(self.preparer.format_constraint(constraint))
345 if constraint.cascade:
346 self.cascade_constraint(constraint)
347 self.execute()
348
349 def cascade_constraint(self, constraint):
350 self.append(" CASCADE")
351
352
353 class ANSIDialect(DefaultDialect):
354 columngenerator = ANSIColumnGenerator
355 columndropper = ANSIColumnDropper
356 schemachanger = ANSISchemaChanger
357 constraintgenerator = ANSIConstraintGenerator
358 constraintdropper = ANSIConstraintDropper
@@ -0,0 +1,202 b''
1 """
2 This module defines standalone schema constraint classes.
3 """
4 from sqlalchemy import schema
5
6 from migrate.exceptions import *
7 from migrate.changeset import SQLA_06
8
9 class ConstraintChangeset(object):
10 """Base class for Constraint classes."""
11
12 def _normalize_columns(self, cols, table_name=False):
13 """Given: column objects or names; return col names and
14 (maybe) a table"""
15 colnames = []
16 table = None
17 for col in cols:
18 if isinstance(col, schema.Column):
19 if col.table is not None and table is None:
20 table = col.table
21 if table_name:
22 col = '.'.join((col.table.name, col.name))
23 else:
24 col = col.name
25 colnames.append(col)
26 return colnames, table
27
28 def __do_imports(self, visitor_name, *a, **kw):
29 engine = kw.pop('engine', self.table.bind)
30 from migrate.changeset.databases.visitor import (get_engine_visitor,
31 run_single_visitor)
32 visitorcallable = get_engine_visitor(engine, visitor_name)
33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
34
35 def create(self, *a, **kw):
36 """Create the constraint in the database.
37
38 :param engine: the database engine to use. If this is \
39 :keyword:`None` the instance's engine will be used
40 :type engine: :class:`sqlalchemy.engine.base.Engine`
41 :param connection: reuse connection istead of creating new one.
42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
43 """
44 # TODO: set the parent here instead of in __init__
45 self.__do_imports('constraintgenerator', *a, **kw)
46
47 def drop(self, *a, **kw):
48 """Drop the constraint from the database.
49
50 :param engine: the database engine to use. If this is
51 :keyword:`None` the instance's engine will be used
52 :param cascade: Issue CASCADE drop if database supports it
53 :type engine: :class:`sqlalchemy.engine.base.Engine`
54 :type cascade: bool
55 :param connection: reuse connection istead of creating new one.
56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
57 :returns: Instance with cleared columns
58 """
59 self.cascade = kw.pop('cascade', False)
60 self.__do_imports('constraintdropper', *a, **kw)
61 # the spirit of Constraint objects is that they
62 # are immutable (just like in a DB. they're only ADDed
63 # or DROPped).
64 #self.columns.clear()
65 return self
66
67
68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
69 """Construct PrimaryKeyConstraint
70
71 Migrate's additional parameters:
72
73 :param cols: Columns in constraint.
74 :param table: If columns are passed as strings, this kw is required
75 :type table: Table instance
76 :type cols: strings or Column instances
77 """
78
79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
80
81 def __init__(self, *cols, **kwargs):
82 colnames, table = self._normalize_columns(cols)
83 table = kwargs.pop('table', table)
84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
85 if table is not None:
86 self._set_parent(table)
87
88
89 def autoname(self):
90 """Mimic the database's automatic constraint names"""
91 return "%s_pkey" % self.table.name
92
93
94 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
95 """Construct ForeignKeyConstraint
96
97 Migrate's additional parameters:
98
99 :param columns: Columns in constraint
100 :param refcolumns: Columns that this FK reffers to in another table.
101 :param table: If columns are passed as strings, this kw is required
102 :type table: Table instance
103 :type columns: list of strings or Column instances
104 :type refcolumns: list of strings or Column instances
105 """
106
107 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
108
109 def __init__(self, columns, refcolumns, *args, **kwargs):
110 colnames, table = self._normalize_columns(columns)
111 table = kwargs.pop('table', table)
112 refcolnames, reftable = self._normalize_columns(refcolumns,
113 table_name=True)
114 super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
115 **kwargs)
116 if table is not None:
117 self._set_parent(table)
118
119 @property
120 def referenced(self):
121 return [e.column for e in self.elements]
122
123 @property
124 def reftable(self):
125 return self.referenced[0].table
126
127 def autoname(self):
128 """Mimic the database's automatic constraint names"""
129 if hasattr(self.columns, 'keys'):
130 # SA <= 0.5
131 firstcol = self.columns[self.columns.keys()[0]]
132 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
133 table=firstcol.table.name,
134 firstcolumn=firstcol.name,)
135 else:
136 # SA >= 0.6
137 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
138 table=self.table.name,
139 firstcolumn=self.columns[0],)
140 return ret
141
142
143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
144 """Construct CheckConstraint
145
146 Migrate's additional parameters:
147
148 :param sqltext: Plain SQL text to check condition
149 :param columns: If not name is applied, you must supply this kw\
150 to autoname constraint
151 :param table: If columns are passed as strings, this kw is required
152 :type table: Table instance
153 :type columns: list of Columns instances
154 :type sqltext: string
155 """
156
157 __migrate_visit_name__ = 'migrate_check_constraint'
158
159 def __init__(self, sqltext, *args, **kwargs):
160 cols = kwargs.pop('columns', [])
161 if not cols and not kwargs.get('name', False):
162 raise InvalidConstraintError('You must either set "name"'
163 'parameter or "columns" to autogenarate it.')
164 colnames, table = self._normalize_columns(cols)
165 table = kwargs.pop('table', table)
166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
167 if table is not None:
168 if not SQLA_06:
169 self.table = table
170 self._set_parent(table)
171 self.colnames = colnames
172
173 def autoname(self):
174 return "%(table)s_%(cols)s_check" % \
175 dict(table=self.table.name, cols="_".join(self.colnames))
176
177
178 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
179 """Construct UniqueConstraint
180
181 Migrate's additional parameters:
182
183 :param cols: Columns in constraint.
184 :param table: If columns are passed as strings, this kw is required
185 :type table: Table instance
186 :type cols: strings or Column instances
187
188 .. versionadded:: 0.6.0
189 """
190
191 __migrate_visit_name__ = 'migrate_unique_constraint'
192
193 def __init__(self, *cols, **kwargs):
194 self.colnames, table = self._normalize_columns(cols)
195 table = kwargs.pop('table', table)
196 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
197 if table is not None:
198 self._set_parent(table)
199
200 def autoname(self):
201 """Mimic the database's automatic constraint names"""
202 return "%s_%s_key" % (self.table.name, self.colnames[0])
@@ -0,0 +1,10 b''
1 """
2 This module contains database dialect specific changeset
3 implementations.
4 """
5 __all__ = [
6 'postgres',
7 'sqlite',
8 'mysql',
9 'oracle',
10 ]
@@ -0,0 +1,80 b''
1 """
2 Firebird database specific implementations of changeset classes.
3 """
4 from sqlalchemy.databases import firebird as sa_base
5
6 from migrate import exceptions
7 from migrate.changeset import ansisql, SQLA_06
8
9
10 if SQLA_06:
11 FBSchemaGenerator = sa_base.FBDDLCompiler
12 else:
13 FBSchemaGenerator = sa_base.FBSchemaGenerator
14
15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
16 """Firebird column generator implementation."""
17
18
19 class FBColumnDropper(ansisql.ANSIColumnDropper):
20 """Firebird column dropper implementation."""
21
22 def visit_column(self, column):
23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
24
25 Drop primary key and unique constraints if dropped column is referencing it."""
26 if column.primary_key:
27 if column.table.primary_key.columns.contains_column(column):
28 column.table.primary_key.drop()
29 # TODO: recreate primary key if it references more than this column
30 if column.unique or getattr(column, 'unique_name', None):
31 for cons in column.table.constraints:
32 if cons.contains_column(column):
33 cons.drop()
34 # TODO: recreate unique constraint if it refenrences more than this column
35
36 table = self.start_alter_table(column)
37 self.append('DROP %s' % self.preparer.format_column(column))
38 self.execute()
39
40
41 class FBSchemaChanger(ansisql.ANSISchemaChanger):
42 """Firebird schema changer implementation."""
43
44 def visit_table(self, table):
45 """Rename table not supported"""
46 raise exceptions.NotSupportedError(
47 "Firebird does not support renaming tables.")
48
49 def _visit_column_name(self, table, column, delta):
50 self.start_alter_table(table)
51 col_name = self.preparer.quote(delta.current_name, table.quote)
52 new_name = self.preparer.format_column(delta.result_column)
53 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
54
55 def _visit_column_nullable(self, table, column, delta):
56 """Changing NULL is not supported"""
57 # TODO: http://www.firebirdfaq.org/faq103/
58 raise exceptions.NotSupportedError(
59 "Firebird does not support altering NULL bevahior.")
60
61
62 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
63 """Firebird constraint generator implementation."""
64
65
66 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
67 """Firebird constaint dropper implementation."""
68
69 def cascade_constraint(self, constraint):
70 """Cascading constraints is not supported"""
71 raise exceptions.NotSupportedError(
72 "Firebird does not support cascading constraints")
73
74
75 class FBDialect(ansisql.ANSIDialect):
76 columngenerator = FBColumnGenerator
77 columndropper = FBColumnDropper
78 schemachanger = FBSchemaChanger
79 constraintgenerator = FBConstraintGenerator
80 constraintdropper = FBConstraintDropper
@@ -0,0 +1,94 b''
1 """
2 MySQL database specific implementations of changeset classes.
3 """
4
5 from sqlalchemy.databases import mysql as sa_base
6 from sqlalchemy import types as sqltypes
7
8 from migrate import exceptions
9 from migrate.changeset import ansisql, SQLA_06
10
11
12 if not SQLA_06:
13 MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
14 else:
15 MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
16
17 class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
18 pass
19
20
21 class MySQLColumnDropper(ansisql.ANSIColumnDropper):
22 pass
23
24
25 class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
26
27 def visit_column(self, delta):
28 table = delta.table
29 colspec = self.get_column_specification(delta.result_column)
30 if delta.result_column.autoincrement:
31 primary_keys = [c for c in table.primary_key.columns
32 if (c.autoincrement and
33 isinstance(c.type, sqltypes.Integer) and
34 not c.foreign_keys)]
35
36 if primary_keys:
37 first = primary_keys.pop(0)
38 if first.name == delta.current_name:
39 colspec += " AUTO_INCREMENT"
40 old_col_name = self.preparer.quote(delta.current_name, table.quote)
41
42 self.start_alter_table(table)
43
44 self.append("CHANGE COLUMN %s " % old_col_name)
45 self.append(colspec)
46 self.execute()
47
48 def visit_index(self, param):
49 # If MySQL can do this, I can't find how
50 raise exceptions.NotSupportedError("MySQL cannot rename indexes")
51
52
53 class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
54 pass
55
56 if SQLA_06:
57 class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
58 def visit_migrate_check_constraint(self, *p, **k):
59 raise exceptions.NotSupportedError("MySQL does not support CHECK"
60 " constraints, use triggers instead.")
61
62 else:
63 class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
64
65 def visit_migrate_primary_key_constraint(self, constraint):
66 self.start_alter_table(constraint)
67 self.append("DROP PRIMARY KEY")
68 self.execute()
69
70 def visit_migrate_foreign_key_constraint(self, constraint):
71 self.start_alter_table(constraint)
72 self.append("DROP FOREIGN KEY ")
73 constraint.name = self.get_constraint_name(constraint)
74 self.append(self.preparer.format_constraint(constraint))
75 self.execute()
76
77 def visit_migrate_check_constraint(self, *p, **k):
78 raise exceptions.NotSupportedError("MySQL does not support CHECK"
79 " constraints, use triggers instead.")
80
81 def visit_migrate_unique_constraint(self, constraint, *p, **k):
82 self.start_alter_table(constraint)
83 self.append('DROP INDEX ')
84 constraint.name = self.get_constraint_name(constraint)
85 self.append(self.preparer.format_constraint(constraint))
86 self.execute()
87
88
89 class MySQLDialect(ansisql.ANSIDialect):
90 columngenerator = MySQLColumnGenerator
91 columndropper = MySQLColumnDropper
92 schemachanger = MySQLSchemaChanger
93 constraintgenerator = MySQLConstraintGenerator
94 constraintdropper = MySQLConstraintDropper
@@ -0,0 +1,111 b''
1 """
2 Oracle database specific implementations of changeset classes.
3 """
4 import sqlalchemy as sa
5 from sqlalchemy.databases import oracle as sa_base
6
7 from migrate import exceptions
8 from migrate.changeset import ansisql, SQLA_06
9
10
11 if not SQLA_06:
12 OracleSchemaGenerator = sa_base.OracleSchemaGenerator
13 else:
14 OracleSchemaGenerator = sa_base.OracleDDLCompiler
15
16
17 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
18 pass
19
20
21 class OracleColumnDropper(ansisql.ANSIColumnDropper):
22 pass
23
24
25 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
26
27 def get_column_specification(self, column, **kwargs):
28 # Ignore the NOT NULL generated
29 override_nullable = kwargs.pop('override_nullable', None)
30 if override_nullable:
31 orig = column.nullable
32 column.nullable = True
33 ret = super(OracleSchemaChanger, self).get_column_specification(
34 column, **kwargs)
35 if override_nullable:
36 column.nullable = orig
37 return ret
38
39 def visit_column(self, delta):
40 keys = delta.keys()
41
42 if 'name' in keys:
43 self._run_subvisit(delta,
44 self._visit_column_name,
45 start_alter=False)
46
47 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
48 self._run_subvisit(delta,
49 self._visit_column_change,
50 start_alter=False)
51
52 def _visit_column_change(self, table, column, delta):
53 # Oracle cannot drop a default once created, but it can set it
54 # to null. We'll do that if default=None
55 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
56 dropdefault_hack = (column.server_default is None \
57 and 'server_default' in delta.keys())
58 # Oracle apparently doesn't like it when we say "not null" if
59 # the column's already not null. Fudge it, so we don't need a
60 # new function
61 notnull_hack = ((not column.nullable) \
62 and ('nullable' not in delta.keys()))
63 # We need to specify NULL if we're removing a NOT NULL
64 # constraint
65 null_hack = (column.nullable and ('nullable' in delta.keys()))
66
67 if dropdefault_hack:
68 column.server_default = sa.PassiveDefault(sa.sql.null())
69 if notnull_hack:
70 column.nullable = True
71 colspec = self.get_column_specification(column,
72 override_nullable=null_hack)
73 if null_hack:
74 colspec += ' NULL'
75 if notnull_hack:
76 column.nullable = False
77 if dropdefault_hack:
78 column.server_default = None
79
80 self.start_alter_table(table)
81 self.append("MODIFY (")
82 self.append(colspec)
83 self.append(")")
84
85
86 class OracleConstraintCommon(object):
87
88 def get_constraint_name(self, cons):
89 # Oracle constraints can't guess their name like other DBs
90 if not cons.name:
91 raise exceptions.NotSupportedError(
92 "Oracle constraint names must be explicitly stated")
93 return cons.name
94
95
96 class OracleConstraintGenerator(OracleConstraintCommon,
97 ansisql.ANSIConstraintGenerator):
98 pass
99
100
101 class OracleConstraintDropper(OracleConstraintCommon,
102 ansisql.ANSIConstraintDropper):
103 pass
104
105
106 class OracleDialect(ansisql.ANSIDialect):
107 columngenerator = OracleColumnGenerator
108 columndropper = OracleColumnDropper
109 schemachanger = OracleSchemaChanger
110 constraintgenerator = OracleConstraintGenerator
111 constraintdropper = OracleConstraintDropper
@@ -0,0 +1,46 b''
1 """
2 `PostgreSQL`_ database specific implementations of changeset classes.
3
4 .. _`PostgreSQL`: http://www.postgresql.org/
5 """
6 from migrate.changeset import ansisql, SQLA_06
7
8 if not SQLA_06:
9 from sqlalchemy.databases import postgres as sa_base
10 PGSchemaGenerator = sa_base.PGSchemaGenerator
11 else:
12 from sqlalchemy.databases import postgresql as sa_base
13 PGSchemaGenerator = sa_base.PGDDLCompiler
14
15
16 class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
17 """PostgreSQL column generator implementation."""
18 pass
19
20
21 class PGColumnDropper(ansisql.ANSIColumnDropper):
22 """PostgreSQL column dropper implementation."""
23 pass
24
25
26 class PGSchemaChanger(ansisql.ANSISchemaChanger):
27 """PostgreSQL schema changer implementation."""
28 pass
29
30
31 class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
32 """PostgreSQL constraint generator implementation."""
33 pass
34
35
36 class PGConstraintDropper(ansisql.ANSIConstraintDropper):
37 """PostgreSQL constaint dropper implementation."""
38 pass
39
40
41 class PGDialect(ansisql.ANSIDialect):
42 columngenerator = PGColumnGenerator
43 columndropper = PGColumnDropper
44 schemachanger = PGSchemaChanger
45 constraintgenerator = PGConstraintGenerator
46 constraintdropper = PGConstraintDropper
@@ -0,0 +1,148 b''
1 """
2 `SQLite`_ database specific implementations of changeset classes.
3
4 .. _`SQLite`: http://www.sqlite.org/
5 """
6 from UserDict import DictMixin
7 from copy import copy
8
9 from sqlalchemy.databases import sqlite as sa_base
10
11 from migrate import exceptions
12 from migrate.changeset import ansisql, SQLA_06
13
14
15 if not SQLA_06:
16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
17 else:
18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
19
20 class SQLiteCommon(object):
21
22 def _not_supported(self, op):
23 raise exceptions.NotSupportedError("SQLite does not support "
24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
25
26
27 class SQLiteHelper(SQLiteCommon):
28
29 def recreate_table(self,table,column=None,delta=None):
30 table_name = self.preparer.format_table(table)
31
32 # we remove all indexes so as not to have
33 # problems during copy and re-create
34 for index in table.indexes:
35 index.drop()
36
37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
38 self.execute()
39
40 insertion_string = self._modify_table(table, column, delta)
41
42 table.create()
43 self.append(insertion_string % {'table_name': table_name})
44 self.execute()
45 self.append('DROP TABLE migration_tmp')
46 self.execute()
47
48 def visit_column(self, delta):
49 if isinstance(delta, DictMixin):
50 column = delta.result_column
51 table = self._to_table(delta.table)
52 else:
53 column = delta
54 table = self._to_table(column.table)
55 self.recreate_table(table,column,delta)
56
57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
58 ansisql.ANSIColumnGenerator,
59 # at the end so we get the normal
60 # visit_column by default
61 SQLiteHelper,
62 SQLiteCommon
63 ):
64 """SQLite ColumnGenerator"""
65
66 def _modify_table(self, table, column, delta):
67 columns = ' ,'.join(map(
68 self.preparer.format_column,
69 [c for c in table.columns if c.name!=column.name]))
70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
72
73 def visit_column(self,column):
74 if column.foreign_keys:
75 SQLiteHelper.visit_column(self,column)
76 else:
77 super(SQLiteColumnGenerator,self).visit_column(column)
78
79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
80 """SQLite ColumnDropper"""
81
82 def _modify_table(self, table, column, delta):
83 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
84 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
85 ' from migration_tmp'
86
87
88 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
89 """SQLite SchemaChanger"""
90
91 def _modify_table(self, table, column, delta):
92 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
93
94 def visit_index(self, index):
95 """Does not support ALTER INDEX"""
96 self._not_supported('ALTER INDEX')
97
98
99 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
100
101 def visit_migrate_primary_key_constraint(self, constraint):
102 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
103 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
104 tname = self.preparer.format_table(constraint.table)
105 name = self.get_constraint_name(constraint)
106 msg = tmpl % (name, tname, cols)
107 self.append(msg)
108 self.execute()
109
110 def _modify_table(self, table, column, delta):
111 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
112
113 def visit_migrate_foreign_key_constraint(self, *p, **k):
114 self.recreate_table(p[0].table)
115
116 def visit_migrate_unique_constraint(self, *p, **k):
117 self.recreate_table(p[0].table)
118
119
120 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
121 SQLiteCommon,
122 ansisql.ANSIConstraintCommon):
123
124 def visit_migrate_primary_key_constraint(self, constraint):
125 tmpl = "DROP INDEX %s "
126 name = self.get_constraint_name(constraint)
127 msg = tmpl % (name)
128 self.append(msg)
129 self.execute()
130
131 def visit_migrate_foreign_key_constraint(self, *p, **k):
132 self._not_supported('ALTER TABLE DROP CONSTRAINT')
133
134 def visit_migrate_check_constraint(self, *p, **k):
135 self._not_supported('ALTER TABLE DROP CONSTRAINT')
136
137 def visit_migrate_unique_constraint(self, *p, **k):
138 self._not_supported('ALTER TABLE DROP CONSTRAINT')
139
140
141 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
142
143 class SQLiteDialect(ansisql.ANSIDialect):
144 columngenerator = SQLiteColumnGenerator
145 columndropper = SQLiteColumnDropper
146 schemachanger = SQLiteSchemaChanger
147 constraintgenerator = SQLiteConstraintGenerator
148 constraintdropper = SQLiteConstraintDropper
@@ -0,0 +1,78 b''
1 """
2 Module for visitor class mapping.
3 """
4 import sqlalchemy as sa
5
6 from migrate.changeset import ansisql
7 from migrate.changeset.databases import (sqlite,
8 postgres,
9 mysql,
10 oracle,
11 firebird)
12
13
14 # Map SA dialects to the corresponding Migrate extensions
15 DIALECTS = {
16 "default": ansisql.ANSIDialect,
17 "sqlite": sqlite.SQLiteDialect,
18 "postgres": postgres.PGDialect,
19 "postgresql": postgres.PGDialect,
20 "mysql": mysql.MySQLDialect,
21 "oracle": oracle.OracleDialect,
22 "firebird": firebird.FBDialect,
23 }
24
25
26 def get_engine_visitor(engine, name):
27 """
28 Get the visitor implementation for the given database engine.
29
30 :param engine: SQLAlchemy Engine
31 :param name: Name of the visitor
32 :type name: string
33 :type engine: Engine
34 :returns: visitor
35 """
36 # TODO: link to supported visitors
37 return get_dialect_visitor(engine.dialect, name)
38
39
40 def get_dialect_visitor(sa_dialect, name):
41 """
42 Get the visitor implementation for the given dialect.
43
44 Finds the visitor implementation based on the dialect class and
45 returns and instance initialized with the given name.
46
47 Binds dialect specific preparer to visitor.
48 """
49
50 # map sa dialect to migrate dialect and return visitor
51 sa_dialect_name = getattr(sa_dialect, 'name', 'default')
52 migrate_dialect_cls = DIALECTS[sa_dialect_name]
53 visitor = getattr(migrate_dialect_cls, name)
54
55 # bind preparer
56 visitor.preparer = sa_dialect.preparer(sa_dialect)
57
58 return visitor
59
60 def run_single_visitor(engine, visitorcallable, element,
61 connection=None, **kwargs):
62 """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
63 with support for migrate visitors.
64 """
65 if connection is None:
66 conn = engine.contextual_connect(close_with_result=False)
67 else:
68 conn = connection
69 visitor = visitorcallable(engine.dialect, conn)
70 try:
71 if hasattr(element, '__migrate_visit_name__'):
72 fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
73 else:
74 fn = getattr(visitor, 'visit_' + element.__visit_name__)
75 fn(element, **kwargs)
76 finally:
77 if connection is None:
78 conn.close()
This diff has been collapsed as it changes many lines, (669 lines changed) Show them Hide them
@@ -0,0 +1,669 b''
1 """
2 Schema module providing common schema operations.
3 """
4 import warnings
5
6 from UserDict import DictMixin
7
8 import sqlalchemy
9
10 from sqlalchemy.schema import ForeignKeyConstraint
11 from sqlalchemy.schema import UniqueConstraint
12
13 from migrate.exceptions import *
14 from migrate.changeset import SQLA_06
15 from migrate.changeset.databases.visitor import (get_engine_visitor,
16 run_single_visitor)
17
18
19 __all__ = [
20 'create_column',
21 'drop_column',
22 'alter_column',
23 'rename_table',
24 'rename_index',
25 'ChangesetTable',
26 'ChangesetColumn',
27 'ChangesetIndex',
28 'ChangesetDefaultClause',
29 'ColumnDelta',
30 ]
31
32 DEFAULT_ALTER_METADATA = True
33
34
35 def create_column(column, table=None, *p, **kw):
36 """Create a column, given the table.
37
38 API to :meth:`ChangesetColumn.create`.
39 """
40 if table is not None:
41 return table.create_column(column, *p, **kw)
42 return column.create(*p, **kw)
43
44
45 def drop_column(column, table=None, *p, **kw):
46 """Drop a column, given the table.
47
48 API to :meth:`ChangesetColumn.drop`.
49 """
50 if table is not None:
51 return table.drop_column(column, *p, **kw)
52 return column.drop(*p, **kw)
53
54
55 def rename_table(table, name, engine=None, **kw):
56 """Rename a table.
57
58 If Table instance is given, engine is not used.
59
60 API to :meth:`ChangesetTable.rename`.
61
62 :param table: Table to be renamed.
63 :param name: New name for Table.
64 :param engine: Engine instance.
65 :type table: string or Table instance
66 :type name: string
67 :type engine: obj
68 """
69 table = _to_table(table, engine)
70 table.rename(name, **kw)
71
72
73 def rename_index(index, name, table=None, engine=None, **kw):
74 """Rename an index.
75
76 If Index instance is given,
77 table and engine are not used.
78
79 API to :meth:`ChangesetIndex.rename`.
80
81 :param index: Index to be renamed.
82 :param name: New name for index.
83 :param table: Table to which Index is reffered.
84 :param engine: Engine instance.
85 :type index: string or Index instance
86 :type name: string
87 :type table: string or Table instance
88 :type engine: obj
89 """
90 index = _to_index(index, table, engine)
91 index.rename(name, **kw)
92
93
94 def alter_column(*p, **k):
95 """Alter a column.
96
97 This is a helper function that creates a :class:`ColumnDelta` and
98 runs it.
99
100 :argument column:
101 The name of the column to be altered or a
102 :class:`ChangesetColumn` column representing it.
103
104 :param table:
105 A :class:`~sqlalchemy.schema.Table` or table name to
106 for the table where the column will be changed.
107
108 :param engine:
109 The :class:`~sqlalchemy.engine.base.Engine` to use for table
110 reflection and schema alterations.
111
112 :param alter_metadata:
113 If `True`, which is the default, the
114 :class:`~sqlalchemy.schema.Column` will also modified.
115 If `False`, the :class:`~sqlalchemy.schema.Column` will be left
116 as it was.
117
118 :returns: A :class:`ColumnDelta` instance representing the change.
119
120
121 """
122
123 k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
124
125 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
126 k['table'] = p[0].table
127 if 'engine' not in k:
128 k['engine'] = k['table'].bind
129
130 # deprecation
131 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
132 warnings.warn(
133 "Passing a Column object to alter_column is deprecated."
134 " Just pass in keyword parameters instead.",
135 MigrateDeprecationWarning
136 )
137 engine = k['engine']
138 delta = ColumnDelta(*p, **k)
139
140 visitorcallable = get_engine_visitor(engine, 'schemachanger')
141 engine._run_visitor(visitorcallable, delta)
142
143 return delta
144
145
146 def _to_table(table, engine=None):
147 """Return if instance of Table, else construct new with metadata"""
148 if isinstance(table, sqlalchemy.Table):
149 return table
150
151 # Given: table name, maybe an engine
152 meta = sqlalchemy.MetaData()
153 if engine is not None:
154 meta.bind = engine
155 return sqlalchemy.Table(table, meta)
156
157
158 def _to_index(index, table=None, engine=None):
159 """Return if instance of Index, else construct new with metadata"""
160 if isinstance(index, sqlalchemy.Index):
161 return index
162
163 # Given: index name; table name required
164 table = _to_table(table, engine)
165 ret = sqlalchemy.Index(index)
166 ret.table = table
167 return ret
168
169
170 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
171 """Extracts the differences between two columns/column-parameters
172
173 May receive parameters arranged in several different ways:
174
175 * **current_column, new_column, \*p, \*\*kw**
176 Additional parameters can be specified to override column
177 differences.
178
179 * **current_column, \*p, \*\*kw**
180 Additional parameters alter current_column. Table name is extracted
181 from current_column object.
182 Name is changed to current_column.name from current_name,
183 if current_name is specified.
184
185 * **current_col_name, \*p, \*\*kw**
186 Table kw must specified.
187
188 :param table: Table at which current Column should be bound to.\
189 If table name is given, reflection will be used.
190 :type table: string or Table instance
191 :param alter_metadata: If True, it will apply changes to metadata.
192 :type alter_metadata: bool
193 :param metadata: If `alter_metadata` is true, \
194 metadata is used to reflect table names into
195 :type metadata: :class:`MetaData` instance
196 :param engine: When reflecting tables, either engine or metadata must \
197 be specified to acquire engine object.
198 :type engine: :class:`Engine` instance
199 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
200 `result_column` through :func:`dict` alike object.
201
202 * :class:`ColumnDelta`.result_column is altered column with new attributes
203
204 * :class:`ColumnDelta`.current_name is current name of column in db
205
206
207 """
208
209 # Column attributes that can be altered
210 diff_keys = ('name', 'type', 'primary_key', 'nullable',
211 'server_onupdate', 'server_default', 'autoincrement')
212 diffs = dict()
213 __visit_name__ = 'column'
214
215 def __init__(self, *p, **kw):
216 self.alter_metadata = kw.pop("alter_metadata", False)
217 self.meta = kw.pop("metadata", None)
218 self.engine = kw.pop("engine", None)
219
220 # Things are initialized differently depending on how many column
221 # parameters are given. Figure out how many and call the appropriate
222 # method.
223 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
224 # At least one column specified
225 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
226 # Two columns specified
227 diffs = self.compare_2_columns(*p, **kw)
228 else:
229 # Exactly one column specified
230 diffs = self.compare_1_column(*p, **kw)
231 else:
232 # Zero columns specified
233 if not len(p) or not isinstance(p[0], basestring):
234 raise ValueError("First argument must be column name")
235 diffs = self.compare_parameters(*p, **kw)
236
237 self.apply_diffs(diffs)
238
239 def __repr__(self):
240 return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
241 super(ColumnDelta, self).__repr__())
242
243 def __getitem__(self, key):
244 if key not in self.keys():
245 raise KeyError("No such diff key, available: %s" % self.diffs)
246 return getattr(self.result_column, key)
247
248 def __setitem__(self, key, value):
249 if key not in self.keys():
250 raise KeyError("No such diff key, available: %s" % self.diffs)
251 setattr(self.result_column, key, value)
252
253 def __delitem__(self, key):
254 raise NotImplementedError
255
256 def keys(self):
257 return self.diffs.keys()
258
259 def compare_parameters(self, current_name, *p, **k):
260 """Compares Column objects with reflection"""
261 self.table = k.pop('table')
262 self.result_column = self._table.c.get(current_name)
263 if len(p):
264 k = self._extract_parameters(p, k, self.result_column)
265 return k
266
267 def compare_1_column(self, col, *p, **k):
268 """Compares one Column object"""
269 self.table = k.pop('table', None)
270 if self.table is None:
271 self.table = col.table
272 self.result_column = col
273 if len(p):
274 k = self._extract_parameters(p, k, self.result_column)
275 return k
276
277 def compare_2_columns(self, old_col, new_col, *p, **k):
278 """Compares two Column objects"""
279 self.process_column(new_col)
280 self.table = k.pop('table', None)
281 # we cannot use bool() on table in SA06
282 if self.table is None:
283 self.table = old_col.table
284 if self.table is None:
285 new_col.table
286 self.result_column = old_col
287
288 # set differences
289 # leave out some stuff for later comp
290 for key in (set(self.diff_keys) - set(('type',))):
291 val = getattr(new_col, key, None)
292 if getattr(self.result_column, key, None) != val:
293 k.setdefault(key, val)
294
295 # inspect types
296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
297 k.setdefault('type', new_col.type)
298
299 if len(p):
300 k = self._extract_parameters(p, k, self.result_column)
301 return k
302
303 def apply_diffs(self, diffs):
304 """Populate dict and column object with new values"""
305 self.diffs = diffs
306 for key in self.diff_keys:
307 if key in diffs:
308 setattr(self.result_column, key, diffs[key])
309
310 self.process_column(self.result_column)
311
312 # create an instance of class type if not yet
313 if 'type' in diffs and callable(self.result_column.type):
314 self.result_column.type = self.result_column.type()
315
316 # add column to the table
317 if self.table is not None and self.alter_metadata:
318 self.result_column.add_to_table(self.table)
319
320 def are_column_types_eq(self, old_type, new_type):
321 """Compares two types to be equal"""
322 ret = old_type.__class__ == new_type.__class__
323
324 # String length is a special case
325 if ret and isinstance(new_type, sqlalchemy.types.String):
326 ret = (getattr(old_type, 'length', None) == \
327 getattr(new_type, 'length', None))
328 return ret
329
330 def _extract_parameters(self, p, k, column):
331 """Extracts data from p and modifies diffs"""
332 p = list(p)
333 while len(p):
334 if isinstance(p[0], basestring):
335 k.setdefault('name', p.pop(0))
336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
337 k.setdefault('type', p.pop(0))
338 elif callable(p[0]):
339 p[0] = p[0]()
340 else:
341 break
342
343 if len(p):
344 new_col = column.copy_fixed()
345 new_col._init_items(*p)
346 k = self.compare_2_columns(column, new_col, **k)
347 return k
348
349 def process_column(self, column):
350 """Processes default values for column"""
351 # XXX: this is a snippet from SA processing of positional parameters
352 if not SQLA_06 and column.args:
353 toinit = list(column.args)
354 else:
355 toinit = list()
356
357 if column.server_default is not None:
358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
359 toinit.append(column.server_default)
360 else:
361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
362 if column.server_onupdate is not None:
363 if isinstance(column.server_onupdate, FetchedValue):
364 toinit.append(column.server_default)
365 else:
366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
367 for_update=True))
368 if toinit:
369 column._init_items(*toinit)
370
371 if not SQLA_06:
372 column.args = []
373
374 def _get_table(self):
375 return getattr(self, '_table', None)
376
377 def _set_table(self, table):
378 if isinstance(table, basestring):
379 if self.alter_metadata:
380 if not self.meta:
381 raise ValueError("metadata must be specified for table"
382 " reflection when using alter_metadata")
383 meta = self.meta
384 if self.engine:
385 meta.bind = self.engine
386 else:
387 if not self.engine and not self.meta:
388 raise ValueError("engine or metadata must be specified"
389 " to reflect tables")
390 if not self.engine:
391 self.engine = self.meta.bind
392 meta = sqlalchemy.MetaData(bind=self.engine)
393 self._table = sqlalchemy.Table(table, meta, autoload=True)
394 elif isinstance(table, sqlalchemy.Table):
395 self._table = table
396 if not self.alter_metadata:
397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
398
399 def _get_result_column(self):
400 return getattr(self, '_result_column', None)
401
402 def _set_result_column(self, column):
403 """Set Column to Table based on alter_metadata evaluation."""
404 self.process_column(column)
405 if not hasattr(self, 'current_name'):
406 self.current_name = column.name
407 if self.alter_metadata:
408 self._result_column = column
409 else:
410 self._result_column = column.copy_fixed()
411
412 table = property(_get_table, _set_table)
413 result_column = property(_get_result_column, _set_result_column)
414
415
416 class ChangesetTable(object):
417 """Changeset extensions to SQLAlchemy tables."""
418
419 def create_column(self, column, *p, **kw):
420 """Creates a column.
421
422 The column parameter may be a column definition or the name of
423 a column in this table.
424
425 API to :meth:`ChangesetColumn.create`
426
427 :param column: Column to be created
428 :type column: Column instance or string
429 """
430 if not isinstance(column, sqlalchemy.Column):
431 # It's a column name
432 column = getattr(self.c, str(column))
433 column.create(table=self, *p, **kw)
434
435 def drop_column(self, column, *p, **kw):
436 """Drop a column, given its name or definition.
437
438 API to :meth:`ChangesetColumn.drop`
439
440 :param column: Column to be droped
441 :type column: Column instance or string
442 """
443 if not isinstance(column, sqlalchemy.Column):
444 # It's a column name
445 try:
446 column = getattr(self.c, str(column))
447 except AttributeError:
448 # That column isn't part of the table. We don't need
449 # its entire definition to drop the column, just its
450 # name, so create a dummy column with the same name.
451 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
452 column.drop(table=self, *p, **kw)
453
454 def rename(self, name, connection=None, **kwargs):
455 """Rename this table.
456
457 :param name: New name of the table.
458 :type name: string
459 :param alter_metadata: If True, table will be removed from metadata
460 :type alter_metadata: bool
461 :param connection: reuse connection istead of creating new one.
462 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
463 """
464 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
465 engine = self.bind
466 self.new_name = name
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469
470 # Fix metadata registration
471 if self.alter_metadata:
472 self.name = name
473 self.deregister()
474 self._set_parent(self.metadata)
475
476 def _meta_key(self):
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478
479 def deregister(self):
480 """Remove this table from its metadata"""
481 key = self._meta_key()
482 meta = self.metadata
483 if key in meta.tables:
484 del meta.tables[key]
485
486
487 class ChangesetColumn(object):
488 """Changeset extensions to SQLAlchemy columns."""
489
490 def alter(self, *p, **k):
491 """Makes a call to :func:`alter_column` for the column this
492 method is called on.
493 """
494 if 'table' not in k:
495 k['table'] = self.table
496 if 'engine' not in k:
497 k['engine'] = k['table'].bind
498 return alter_column(self, *p, **k)
499
500 def create(self, table=None, index_name=None, unique_name=None,
501 primary_key_name=None, populate_default=True, connection=None, **kwargs):
502 """Create this column in the database.
503
504 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
505 for most databases.
506
507 :param table: Table instance to create on.
508 :param index_name: Creates :class:`ChangesetIndex` on this column.
509 :param unique_name: Creates :class:\
510 `~migrate.changeset.constraint.UniqueConstraint` on this column.
511 :param primary_key_name: Creates :class:\
512 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
513 :param alter_metadata: If True, column will be added to table object.
514 :param populate_default: If True, created column will be \
515 populated with defaults
516 :param connection: reuse connection istead of creating new one.
517 :type table: Table instance
518 :type index_name: string
519 :type unique_name: string
520 :type primary_key_name: string
521 :type alter_metadata: bool
522 :type populate_default: bool
523 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
524
525 :returns: self
526 """
527 self.populate_default = populate_default
528 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
529 self.index_name = index_name
530 self.unique_name = unique_name
531 self.primary_key_name = primary_key_name
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 self._check_sanity_constraints(cons)
534
535 if self.alter_metadata:
536 self.add_to_table(table)
537 engine = self.table.bind
538 visitorcallable = get_engine_visitor(engine, 'columngenerator')
539 engine._run_visitor(visitorcallable, self, connection, **kwargs)
540
541 # TODO: reuse existing connection
542 if self.populate_default and self.default is not None:
543 stmt = table.update().values({self: engine._execute_default(self.default)})
544 engine.execute(stmt)
545
546 return self
547
548 def drop(self, table=None, connection=None, **kwargs):
549 """Drop this column from the database, leaving its table intact.
550
551 ``ALTER TABLE DROP COLUMN``, for most databases.
552
553 :param alter_metadata: If True, column will be removed from table object.
554 :type alter_metadata: bool
555 :param connection: reuse connection istead of creating new one.
556 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
557 """
558 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
559 if table is not None:
560 self.table = table
561 engine = self.table.bind
562 if self.alter_metadata:
563 self.remove_from_table(self.table, unset_table=False)
564 visitorcallable = get_engine_visitor(engine, 'columndropper')
565 engine._run_visitor(visitorcallable, self, connection, **kwargs)
566 if self.alter_metadata:
567 self.table = None
568 return self
569
570 def add_to_table(self, table):
571 if table is not None and self.table is None:
572 self._set_parent(table)
573
574 def _col_name_in_constraint(self, cons, name):
575 return False
576
577 def remove_from_table(self, table, unset_table=True):
578 # TODO: remove primary keys, constraints, etc
579 if unset_table:
580 self.table = None
581
582 to_drop = set()
583 for index in table.indexes:
584 columns = []
585 for col in index.columns:
586 if col.name != self.name:
587 columns.append(col)
588 if columns:
589 index.columns = columns
590 else:
591 to_drop.add(index)
592 table.indexes = table.indexes - to_drop
593
594 to_drop = set()
595 for cons in table.constraints:
596 # TODO: deal with other types of constraint
597 if isinstance(cons, (ForeignKeyConstraint,
598 UniqueConstraint)):
599 for col_name in cons.columns:
600 if not isinstance(col_name, basestring):
601 col_name = col_name.name
602 if self.name == col_name:
603 to_drop.add(cons)
604 table.constraints = table.constraints - to_drop
605
606 if table.c.contains_column(self):
607 table.c.remove(self)
608
609 # TODO: this is fixed in 0.6
610 def copy_fixed(self, **kw):
611 """Create a copy of this ``Column``, with all attributes."""
612 return sqlalchemy.Column(self.name, self.type, self.default,
613 key=self.key,
614 primary_key=self.primary_key,
615 nullable=self.nullable,
616 quote=self.quote,
617 index=self.index,
618 unique=self.unique,
619 onupdate=self.onupdate,
620 autoincrement=self.autoincrement,
621 server_default=self.server_default,
622 server_onupdate=self.server_onupdate,
623 *[c.copy(**kw) for c in self.constraints])
624
625 def _check_sanity_constraints(self, name):
626 """Check if constraints names are correct"""
627 obj = getattr(self, name)
628 if (getattr(self, name[:-5]) and not obj):
629 raise InvalidConstraintError("Column.create() accepts index_name,"
630 " primary_key_name and unique_name to generate constraints")
631 if not isinstance(obj, basestring) and obj is not None:
632 raise InvalidConstraintError(
633 "%s argument for column must be constraint name" % name)
634
635
636 class ChangesetIndex(object):
637 """Changeset extensions to SQLAlchemy Indexes."""
638
639 __visit_name__ = 'index'
640
641 def rename(self, name, connection=None, **kwargs):
642 """Change the name of an index.
643
644 :param name: New name of the Index.
645 :type name: string
646 :param alter_metadata: If True, Index object will be altered.
647 :type alter_metadata: bool
648 :param connection: reuse connection istead of creating new one.
649 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
650 """
651 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
652 engine = self.table.bind
653 self.new_name = name
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 if self.alter_metadata:
657 self.name = name
658
659
660 class ChangesetDefaultClause(object):
661 """Implements comparison between :class:`DefaultClause` instances"""
662
663 def __eq__(self, other):
664 if isinstance(other, self.__class__):
665 if self.arg == other.arg:
666 return True
667
668 def __ne__(self, other):
669 return not self.__eq__(other)
@@ -0,0 +1,87 b''
1 """
2 Provide exception classes for :mod:`migrate`
3 """
4
5
6 class Error(Exception):
7 """Error base class."""
8
9
10 class ApiError(Error):
11 """Base class for API errors."""
12
13
14 class KnownError(ApiError):
15 """A known error condition."""
16
17
18 class UsageError(ApiError):
19 """A known error condition where help should be displayed."""
20
21
22 class ControlledSchemaError(Error):
23 """Base class for controlled schema errors."""
24
25
26 class InvalidVersionError(ControlledSchemaError):
27 """Invalid version number."""
28
29
30 class DatabaseNotControlledError(ControlledSchemaError):
31 """Database should be under version control, but it's not."""
32
33
34 class DatabaseAlreadyControlledError(ControlledSchemaError):
35 """Database shouldn't be under version control, but it is"""
36
37
38 class WrongRepositoryError(ControlledSchemaError):
39 """This database is under version control by another repository."""
40
41
42 class NoSuchTableError(ControlledSchemaError):
43 """The table does not exist."""
44
45
46 class PathError(Error):
47 """Base class for path errors."""
48
49
50 class PathNotFoundError(PathError):
51 """A path with no file was required; found a file."""
52
53
54 class PathFoundError(PathError):
55 """A path with a file was required; found no file."""
56
57
58 class RepositoryError(Error):
59 """Base class for repository errors."""
60
61
62 class InvalidRepositoryError(RepositoryError):
63 """Invalid repository error."""
64
65
66 class ScriptError(Error):
67 """Base class for script errors."""
68
69
70 class InvalidScriptError(ScriptError):
71 """Invalid script error."""
72
73
74 class InvalidVersionError(Error):
75 """Invalid version error."""
76
77 # migrate.changeset
78
79 class NotSupportedError(Error):
80 """Not supported error"""
81
82
83 class InvalidConstraintError(Error):
84 """Invalid constraint error"""
85
86 class MigrateDeprecationWarning(DeprecationWarning):
87 """Warning for deprecated features in Migrate"""
@@ -0,0 +1,5 b''
1 """
2 This package provides functionality to create and manage
3 repositories of database schema changesets and to apply these
4 changesets to databases.
5 """
@@ -0,0 +1,383 b''
1 """
2 This module provides an external API to the versioning system.
3
4 .. versionchanged:: 0.6.0
5 :func:`migrate.versioning.api.test` and schema diff functions
6 changed order of positional arguments so all accept `url` and `repository`
7 as first arguments.
8
9 .. versionchanged:: 0.5.4
10 ``--preview_sql`` displays source file when using SQL scripts.
11 If Python script is used, it runs the action with mocked engine and
12 returns captured SQL statements.
13
14 .. versionchanged:: 0.5.4
15 Deprecated ``--echo`` parameter in favour of new
16 :func:`migrate.versioning.util.construct_engine` behavior.
17 """
18
19 # Dear migrate developers,
20 #
21 # please do not comment this module using sphinx syntax because its
22 # docstrings are presented as user help and most users cannot
23 # interpret sphinx annotated ReStructuredText.
24 #
25 # Thanks,
26 # Jan Dittberner
27
28 import sys
29 import inspect
30 import logging
31
32 from migrate import exceptions
33 from migrate.versioning import (repository, schema, version,
34 script as script_) # command name conflict
35 from migrate.versioning.util import catch_known_errors, with_engine
36
37
38 log = logging.getLogger(__name__)
39 command_desc = {
40 'help': 'displays help on a given command',
41 'create': 'create an empty repository at the specified path',
42 'script': 'create an empty change Python script',
43 'script_sql': 'create empty change SQL scripts for given database',
44 'version': 'display the latest version available in a repository',
45 'db_version': 'show the current version of the repository under version control',
46 'source': 'display the Python code for a particular version in this repository',
47 'version_control': 'mark a database as under this repository\'s version control',
48 'upgrade': 'upgrade a database to a later version',
49 'downgrade': 'downgrade a database to an earlier version',
50 'drop_version_control': 'removes version control from a database',
51 'manage': 'creates a Python script that runs Migrate with a set of default values',
52 'test': 'performs the upgrade and downgrade command on the given database',
53 'compare_model_to_db': 'compare MetaData against the current database state',
54 'create_model': 'dump the current database as a Python model to stdout',
55 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
56 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
57 }
58 __all__ = command_desc.keys()
59
60 Repository = repository.Repository
61 ControlledSchema = schema.ControlledSchema
62 VerNum = version.VerNum
63 PythonScript = script_.PythonScript
64 SqlScript = script_.SqlScript
65
66
67 # deprecated
68 def help(cmd=None, **opts):
69 """%prog help COMMAND
70
71 Displays help on a given command.
72 """
73 if cmd is None:
74 raise exceptions.UsageError(None)
75 try:
76 func = globals()[cmd]
77 except:
78 raise exceptions.UsageError(
79 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
80 ret = func.__doc__
81 if sys.argv[0]:
82 ret = ret.replace('%prog', sys.argv[0])
83 return ret
84
85 @catch_known_errors
86 def create(repository, name, **opts):
87 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
88
89 Create an empty repository at the specified path.
90
91 You can specify the version_table to be used; by default, it is
92 'migrate_version'. This table is created in all version-controlled
93 databases.
94 """
95 repo_path = Repository.create(repository, name, **opts)
96
97
98 @catch_known_errors
99 def script(description, repository, **opts):
100 """%prog script DESCRIPTION REPOSITORY_PATH
101
102 Create an empty change script using the next unused version number
103 appended with the given description.
104
105 For instance, manage.py script "Add initial tables" creates:
106 repository/versions/001_Add_initial_tables.py
107 """
108 repo = Repository(repository)
109 repo.create_script(description, **opts)
110
111
112 @catch_known_errors
113 def script_sql(database, repository, **opts):
114 """%prog script_sql DATABASE REPOSITORY_PATH
115
116 Create empty change SQL scripts for given DATABASE, where DATABASE
117 is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
118 or generic ('default').
119
120 For instance, manage.py script_sql postgres creates:
121 repository/versions/001_postgres_upgrade.sql and
122 repository/versions/001_postgres_postgres.sql
123 """
124 repo = Repository(repository)
125 repo.create_script_sql(database, **opts)
126
127
128 def version(repository, **opts):
129 """%prog version REPOSITORY_PATH
130
131 Display the latest version available in a repository.
132 """
133 repo = Repository(repository)
134 return repo.latest
135
136
137 @with_engine
138 def db_version(url, repository, **opts):
139 """%prog db_version URL REPOSITORY_PATH
140
141 Show the current version of the repository with the given
142 connection string, under version control of the specified
143 repository.
144
145 The url should be any valid SQLAlchemy connection string.
146 """
147 engine = opts.pop('engine')
148 schema = ControlledSchema(engine, repository)
149 return schema.version
150
151
152 def source(version, dest=None, repository=None, **opts):
153 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
154
155 Display the Python code for a particular version in this
156 repository. Save it to the file at DESTINATION or, if omitted,
157 send to stdout.
158 """
159 if repository is None:
160 raise exceptions.UsageError("A repository must be specified")
161 repo = Repository(repository)
162 ret = repo.version(version).script().source()
163 if dest is not None:
164 dest = open(dest, 'w')
165 dest.write(ret)
166 dest.close()
167 ret = None
168 return ret
169
170
171 def upgrade(url, repository, version=None, **opts):
172 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
173
174 Upgrade a database to a later version.
175
176 This runs the upgrade() function defined in your change scripts.
177
178 By default, the database is updated to the latest available
179 version. You may specify a version instead, if you wish.
180
181 You may preview the Python or SQL code to be executed, rather than
182 actually executing it, using the appropriate 'preview' option.
183 """
184 err = "Cannot upgrade a database of version %s to version %s. "\
185 "Try 'downgrade' instead."
186 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
187
188
189 def downgrade(url, repository, version, **opts):
190 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
191
192 Downgrade a database to an earlier version.
193
194 This is the reverse of upgrade; this runs the downgrade() function
195 defined in your change scripts.
196
197 You may preview the Python or SQL code to be executed, rather than
198 actually executing it, using the appropriate 'preview' option.
199 """
200 err = "Cannot downgrade a database of version %s to version %s. "\
201 "Try 'upgrade' instead."
202 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
203
204 @with_engine
205 def test(url, repository, **opts):
206 """%prog test URL REPOSITORY_PATH [VERSION]
207
208 Performs the upgrade and downgrade option on the given
209 database. This is not a real test and may leave the database in a
210 bad state. You should therefore better run the test on a copy of
211 your database.
212 """
213 engine = opts.pop('engine')
214 repos = Repository(repository)
215 script = repos.version(None).script()
216
217 # Upgrade
218 log.info("Upgrading...")
219 script.run(engine, 1)
220 log.info("done")
221
222 log.info("Downgrading...")
223 script.run(engine, -1)
224 log.info("done")
225 log.info("Success")
226
227
228 @with_engine
229 def version_control(url, repository, version=None, **opts):
230 """%prog version_control URL REPOSITORY_PATH [VERSION]
231
232 Mark a database as under this repository's version control.
233
234 Once a database is under version control, schema changes should
235 only be done via change scripts in this repository.
236
237 This creates the table version_table in the database.
238
239 The url should be any valid SQLAlchemy connection string.
240
241 By default, the database begins at version 0 and is assumed to be
242 empty. If the database is not empty, you may specify a version at
243 which to begin instead. No attempt is made to verify this
244 version's correctness - the database schema is expected to be
245 identical to what it would be if the database were created from
246 scratch.
247 """
248 engine = opts.pop('engine')
249 ControlledSchema.create(engine, repository, version)
250
251
252 @with_engine
253 def drop_version_control(url, repository, **opts):
254 """%prog drop_version_control URL REPOSITORY_PATH
255
256 Removes version control from a database.
257 """
258 engine = opts.pop('engine')
259 schema = ControlledSchema(engine, repository)
260 schema.drop()
261
262
263 def manage(file, **opts):
264 """%prog manage FILENAME [VARIABLES...]
265
266 Creates a script that runs Migrate with a set of default values.
267
268 For example::
269
270 %prog manage manage.py --repository=/path/to/repository \
271 --url=sqlite:///project.db
272
273 would create the script manage.py. The following two commands
274 would then have exactly the same results::
275
276 python manage.py version
277 %prog version --repository=/path/to/repository
278 """
279 Repository.create_manage_file(file, **opts)
280
281
282 @with_engine
283 def compare_model_to_db(url, repository, model, **opts):
284 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
285
286 Compare the current model (assumed to be a module level variable
287 of type sqlalchemy.MetaData) against the current database.
288
289 NOTE: This is EXPERIMENTAL.
290 """ # TODO: get rid of EXPERIMENTAL label
291 engine = opts.pop('engine')
292 return ControlledSchema.compare_model_to_db(engine, model, repository)
293
294
295 @with_engine
296 def create_model(url, repository, **opts):
297 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
298
299 Dump the current database as a Python model to stdout.
300
301 NOTE: This is EXPERIMENTAL.
302 """ # TODO: get rid of EXPERIMENTAL label
303 engine = opts.pop('engine')
304 declarative = opts.get('declarative', False)
305 return ControlledSchema.create_model(engine, repository, declarative)
306
307
308 @catch_known_errors
309 @with_engine
310 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
311 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
312
313 Create a script changing the old Python model to the new (current)
314 Python model, sending to stdout.
315
316 NOTE: This is EXPERIMENTAL.
317 """ # TODO: get rid of EXPERIMENTAL label
318 engine = opts.pop('engine')
319 return PythonScript.make_update_script_for_model(
320 engine, oldmodel, model, repository, **opts)
321
322
323 @with_engine
324 def update_db_from_model(url, repository, model, **opts):
325 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
326
327 Modify the database to match the structure of the current Python
328 model. This also sets the db_version number to the latest in the
329 repository.
330
331 NOTE: This is EXPERIMENTAL.
332 """ # TODO: get rid of EXPERIMENTAL label
333 engine = opts.pop('engine')
334 schema = ControlledSchema(engine, repository)
335 schema.update_db_from_model(model)
336
337 @with_engine
338 def _migrate(url, repository, version, upgrade, err, **opts):
339 engine = opts.pop('engine')
340 url = str(engine.url)
341 schema = ControlledSchema(engine, repository)
342 version = _migrate_version(schema, version, upgrade, err)
343
344 changeset = schema.changeset(version)
345 for ver, change in changeset:
346 nextver = ver + changeset.step
347 log.info('%s -> %s... ', ver, nextver)
348
349 if opts.get('preview_sql'):
350 if isinstance(change, PythonScript):
351 log.info(change.preview_sql(url, changeset.step, **opts))
352 elif isinstance(change, SqlScript):
353 log.info(change.source())
354
355 elif opts.get('preview_py'):
356 if not isinstance(change, PythonScript):
357 raise exceptions.UsageError("Python source can be only displayed"
358 " for python migration files")
359 source_ver = max(ver, nextver)
360 module = schema.repository.version(source_ver).script().module
361 funcname = upgrade and "upgrade" or "downgrade"
362 func = getattr(module, funcname)
363 log.info(inspect.getsource(func))
364 else:
365 schema.runchange(ver, change, changeset.step)
366 log.info('done')
367
368
369 def _migrate_version(schema, version, upgrade, err):
370 if version is None:
371 return version
372 # Version is specified: ensure we're upgrading in the right direction
373 # (current version < target version for upgrading; reverse for down)
374 version = VerNum(version)
375 cur = schema.version
376 if upgrade is not None:
377 if upgrade:
378 direction = cur <= version
379 else:
380 direction = cur >= version
381 if not direction:
382 raise exceptions.KnownError(err % (cur, version))
383 return version
@@ -0,0 +1,27 b''
1 """
2 Configuration parser module.
3 """
4
5 from ConfigParser import ConfigParser
6
7 from migrate.versioning.config import *
8 from migrate.versioning import pathed
9
10
11 class Parser(ConfigParser):
12 """A project configuration file."""
13
14 def to_dict(self, sections=None):
15 """It's easier to access config values like dictionaries"""
16 return self._sections
17
18
19 class Config(pathed.Pathed, Parser):
20 """Configuration class."""
21
22 def __init__(self, path, *p, **k):
23 """Confirm the config file exists; read it."""
24 self.require_found(path)
25 pathed.Pathed.__init__(self, path)
26 Parser.__init__(self, *p, **k)
27 self.read(path)
@@ -0,0 +1,14 b''
1 #!/usr/bin/python
2 # -*- coding: utf-8 -*-
3
4 from sqlalchemy.util import OrderedDict
5
6
7 __all__ = ['databases', 'operations']
8
9 databases = ('sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird')
10
11 # Map operation names to function names
12 operations = OrderedDict()
13 operations['upgrade'] = 'upgrade'
14 operations['downgrade'] = 'downgrade'
@@ -0,0 +1,254 b''
1 """
2 Code to generate a Python model from a database or differences
3 between a model and database.
4
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
7 """
8
9 import sys
10 import logging
11
12 import sqlalchemy
13
14 import migrate
15 import migrate.changeset
16
17
18 log = logging.getLogger(__name__)
19 HEADER = """
20 ## File autogenerated by genmodel.py
21
22 from sqlalchemy import *
23 meta = MetaData()
24 """
25
26 DECLARATIVE_HEADER = """
27 ## File autogenerated by genmodel.py
28
29 from sqlalchemy import *
30 from sqlalchemy.ext import declarative
31
32 Base = declarative.declarative_base()
33 """
34
35
36 class ModelGenerator(object):
37
38 def __init__(self, diff, engine, declarative=False):
39 self.diff = diff
40 self.engine = engine
41 self.declarative = declarative
42
43 def column_repr(self, col):
44 kwarg = []
45 if col.key != col.name:
46 kwarg.append('key')
47 if col.primary_key:
48 col.primary_key = True # otherwise it dumps it as 1
49 kwarg.append('primary_key')
50 if not col.nullable:
51 kwarg.append('nullable')
52 if col.onupdate:
53 kwarg.append('onupdate')
54 if col.default:
55 if col.primary_key:
56 # I found that PostgreSQL automatically creates a
57 # default value for the sequence, but let's not show
58 # that.
59 pass
60 else:
61 kwarg.append('default')
62 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
63
64 # crs: not sure if this is good idea, but it gets rid of extra
65 # u''
66 name = col.name.encode('utf8')
67
68 type_ = col.type
69 for cls in col.type.__class__.__mro__:
70 if cls.__module__ == 'sqlalchemy.types' and \
71 not cls.__name__.isupper():
72 if cls is not type_.__class__:
73 type_ = cls()
74 break
75
76 data = {
77 'name': name,
78 'type': type_,
79 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
80 'args': ks and ks or ''}
81
82 if data['constraints']:
83 if data['args']:
84 data['args'] = ',' + data['args']
85
86 if data['constraints'] or data['args']:
87 data['maybeComma'] = ','
88 else:
89 data['maybeComma'] = ''
90
91 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
92 commonStuff = commonStuff.strip()
93 data['commonStuff'] = commonStuff
94 if self.declarative:
95 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
96 else:
97 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
98
99 def getTableDefn(self, table):
100 out = []
101 tableName = table.name
102 if self.declarative:
103 out.append("class %(table)s(Base):" % {'table': tableName})
104 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
105 for col in table.columns:
106 out.append(" %s" % self.column_repr(col))
107 else:
108 out.append("%(table)s = Table('%(table)s', meta," % \
109 {'table': tableName})
110 for col in table.columns:
111 out.append(" %s," % self.column_repr(col))
112 out.append(")")
113 return out
114
115 def _get_tables(self,missingA=False,missingB=False,modified=False):
116 to_process = []
117 for bool_,names,metadata in (
118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
120 (modified,self.diff.tables_different,self.diff.metadataA),
121 ):
122 if bool_:
123 for name in names:
124 yield metadata.tables.get(name)
125
126 def toPython(self):
127 """Assume database is current and model is empty."""
128 out = []
129 if self.declarative:
130 out.append(DECLARATIVE_HEADER)
131 else:
132 out.append(HEADER)
133 out.append("")
134 for table in self._get_tables(missingA=True):
135 out.extend(self.getTableDefn(table))
136 out.append("")
137 return '\n'.join(out)
138
139 def toUpgradeDowngradePython(self, indent=' '):
140 ''' Assume model is most current and database is out-of-date. '''
141 decls = ['from migrate.changeset import schema',
142 'meta = MetaData()']
143 for table in self._get_tables(
144 missingA=True,missingB=True,modified=True
145 ):
146 decls.extend(self.getTableDefn(table))
147
148 upgradeCommands, downgradeCommands = [], []
149 for tableName in self.diff.tables_missing_from_A:
150 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
151 downgradeCommands.append("%(table)s.create()" % \
152 {'table': tableName})
153 for tableName in self.diff.tables_missing_from_B:
154 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
155 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
156
157 for tableName in self.diff.tables_different:
158 dbTable = self.diff.metadataB.tables[tableName]
159 missingInDatabase, missingInModel, diffDecl = \
160 self.diff.colDiffs[tableName]
161 for col in missingInDatabase:
162 upgradeCommands.append('%s.columns[%r].create()' % (
163 modelTable, col.name))
164 downgradeCommands.append('%s.columns[%r].drop()' % (
165 modelTable, col.name))
166 for col in missingInModel:
167 upgradeCommands.append('%s.columns[%r].drop()' % (
168 modelTable, col.name))
169 downgradeCommands.append('%s.columns[%r].create()' % (
170 modelTable, col.name))
171 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
172 upgradeCommands.append(
173 'assert False, "Can\'t alter columns: %s:%s=>%s"',
174 modelTable, modelCol.name, databaseCol.name)
175 downgradeCommands.append(
176 'assert False, "Can\'t alter columns: %s:%s=>%s"',
177 modelTable, modelCol.name, databaseCol.name)
178 pre_command = ' meta.bind = migrate_engine'
179
180 return (
181 '\n'.join(decls),
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
183 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
184
185 def _db_can_handle_this_change(self,td):
186 if (td.columns_missing_from_B
187 and not td.columns_missing_from_A
188 and not td.columns_different):
189 # Even sqlite can handle this.
190 return True
191 else:
192 return not self.engine.url.drivername.startswith('sqlite')
193
194 def applyModel(self):
195 """Apply model to current database."""
196
197 meta = sqlalchemy.MetaData(self.engine)
198
199 for table in self._get_tables(missingA=True):
200 table = table.tometadata(meta)
201 table.drop()
202 for table in self._get_tables(missingB=True):
203 table = table.tometadata(meta)
204 table.create()
205 for modelTable in self._get_tables(modified=True):
206 tableName = modelTable.name
207 modelTable = modelTable.tometadata(meta)
208 dbTable = self.diff.metadataB.tables[tableName]
209
210 td = self.diff.tables_different[tableName]
211
212 if self._db_can_handle_this_change(td):
213
214 for col in td.columns_missing_from_B:
215 modelTable.columns[col].create()
216 for col in td.columns_missing_from_A:
217 dbTable.columns[col].drop()
218 # XXX handle column changes here.
219 else:
220 # Sqlite doesn't support drop column, so you have to
221 # do more: create temp table, copy data to it, drop
222 # old table, create new table, copy data back.
223 #
224 # I wonder if this is guaranteed to be unique?
225 tempName = '_temp_%s' % modelTable.name
226
227 def getCopyStatement():
228 preparer = self.engine.dialect.preparer
229 commonCols = []
230 for modelCol in modelTable.columns:
231 if modelCol.name in dbTable.columns:
232 commonCols.append(modelCol.name)
233 commonColsStr = ', '.join(commonCols)
234 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
235 (tableName, commonColsStr, commonColsStr, tempName)
236
237 # Move the data in one transaction, so that we don't
238 # leave the database in a nasty state.
239 connection = self.engine.connect()
240 trans = connection.begin()
241 try:
242 connection.execute(
243 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
244 (tempName, modelTable.name))
245 # make sure the drop takes place inside our
246 # transaction with the bind parameter
247 modelTable.drop(bind=connection)
248 modelTable.create(bind=connection)
249 connection.execute(getCopyStatement())
250 connection.execute('DROP TABLE %s' % tempName)
251 trans.commit()
252 except:
253 trans.rollback()
254 raise
@@ -0,0 +1,100 b''
1 """
2 Script to migrate repository from sqlalchemy <= 0.4.4 to the new
3 repository schema. This shouldn't use any other migrate modules, so
4 that it can work in any version.
5 """
6
7 import os
8 import sys
9 import logging
10
11 log = logging.getLogger(__name__)
12
13
14 def usage():
15 """Gives usage information."""
16 print """Usage: %(prog)s repository-to-migrate
17
18 Upgrade your repository to the new flat format.
19
20 NOTE: You should probably make a backup before running this.
21 """ % {'prog': sys.argv[0]}
22
23 sys.exit(1)
24
25
26 def delete_file(filepath):
27 """Deletes a file and prints a message."""
28 log.info('Deleting file: %s' % filepath)
29 os.remove(filepath)
30
31
32 def move_file(src, tgt):
33 """Moves a file and prints a message."""
34 log.info('Moving file %s to %s' % (src, tgt))
35 if os.path.exists(tgt):
36 raise Exception(
37 'Cannot move file %s because target %s already exists' % \
38 (src, tgt))
39 os.rename(src, tgt)
40
41
42 def delete_directory(dirpath):
43 """Delete a directory and print a message."""
44 log.info('Deleting directory: %s' % dirpath)
45 os.rmdir(dirpath)
46
47
48 def migrate_repository(repos):
49 """Does the actual migration to the new repository format."""
50 log.info('Migrating repository at: %s to new format' % repos)
51 versions = '%s/versions' % repos
52 dirs = os.listdir(versions)
53 # Only use int's in list.
54 numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
55 numdirs.sort() # Sort list.
56 for dirname in numdirs:
57 origdir = '%s/%s' % (versions, dirname)
58 log.info('Working on directory: %s' % origdir)
59 files = os.listdir(origdir)
60 files.sort()
61 for filename in files:
62 # Delete compiled Python files.
63 if filename.endswith('.pyc') or filename.endswith('.pyo'):
64 delete_file('%s/%s' % (origdir, filename))
65
66 # Delete empty __init__.py files.
67 origfile = '%s/__init__.py' % origdir
68 if os.path.exists(origfile) and len(open(origfile).read()) == 0:
69 delete_file(origfile)
70
71 # Move sql upgrade scripts.
72 if filename.endswith('.sql'):
73 version, dbms, operation = filename.split('.', 3)[0:3]
74 origfile = '%s/%s' % (origdir, filename)
75 # For instance: 2.postgres.upgrade.sql ->
76 # 002_postgres_upgrade.sql
77 tgtfile = '%s/%03d_%s_%s.sql' % (
78 versions, int(version), dbms, operation)
79 move_file(origfile, tgtfile)
80
81 # Move Python upgrade script.
82 pyfile = '%s.py' % dirname
83 pyfilepath = '%s/%s' % (origdir, pyfile)
84 if os.path.exists(pyfilepath):
85 tgtfile = '%s/%03d.py' % (versions, int(dirname))
86 move_file(pyfilepath, tgtfile)
87
88 # Try to remove directory. Will fail if it's not empty.
89 delete_directory(origdir)
90
91
92 def main():
93 """Main function to be called when using this script."""
94 if len(sys.argv) != 2:
95 usage()
96 migrate_repository(sys.argv[1])
97
98
99 if __name__ == '__main__':
100 main()
@@ -0,0 +1,75 b''
1 """
2 A path/directory class.
3 """
4
5 import os
6 import shutil
7 import logging
8
9 from migrate import exceptions
10 from migrate.versioning.config import *
11 from migrate.versioning.util import KeyedInstance
12
13
14 log = logging.getLogger(__name__)
15
16 class Pathed(KeyedInstance):
17 """
18 A class associated with a path/directory tree.
19
20 Only one instance of this class may exist for a particular file;
21 __new__ will return an existing instance if possible
22 """
23 parent = None
24
25 @classmethod
26 def _key(cls, path):
27 return str(path)
28
29 def __init__(self, path):
30 self.path = path
31 if self.__class__.parent is not None:
32 self._init_parent(path)
33
34 def _init_parent(self, path):
35 """Try to initialize this object's parent, if it has one"""
36 parent_path = self.__class__._parent_path(path)
37 self.parent = self.__class__.parent(parent_path)
38 log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
39 self.parent._init_child(path, self)
40
41 def _init_child(self, child, path):
42 """Run when a child of this object is initialized.
43
44 Parameters: the child object; the path to this object (its
45 parent)
46 """
47
48 @classmethod
49 def _parent_path(cls, path):
50 """
51 Fetch the path of this object's parent from this object's path.
52 """
53 # os.path.dirname(), but strip directories like files (like
54 # unix basename)
55 #
56 # Treat directories like files...
57 if path[-1] == '/':
58 path = path[:-1]
59 ret = os.path.dirname(path)
60 return ret
61
62 @classmethod
63 def require_notfound(cls, path):
64 """Ensures a given path does not already exist"""
65 if os.path.exists(path):
66 raise exceptions.PathFoundError(path)
67
68 @classmethod
69 def require_found(cls, path):
70 """Ensures a given path already exists"""
71 if not os.path.exists(path):
72 raise exceptions.PathNotFoundError(path)
73
74 def __str__(self):
75 return self.path
@@ -0,0 +1,231 b''
1 """
2 SQLAlchemy migrate repository management.
3 """
4 import os
5 import shutil
6 import string
7 import logging
8
9 from pkg_resources import resource_filename
10 from tempita import Template as TempitaTemplate
11
12 from migrate import exceptions
13 from migrate.versioning import version, pathed, cfgparse
14 from migrate.versioning.template import Template
15 from migrate.versioning.config import *
16
17
18 log = logging.getLogger(__name__)
19
20 class Changeset(dict):
21 """A collection of changes to be applied to a database.
22
23 Changesets are bound to a repository and manage a set of
24 scripts from that repository.
25
26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 """
28
29 def __init__(self, start, *changes, **k):
30 """
31 Give a start version; step must be explicitly stated.
32 """
33 self.step = k.pop('step', 1)
34 self.start = version.VerNum(start)
35 self.end = self.start
36 for change in changes:
37 self.add(change)
38
39 def __iter__(self):
40 return iter(self.items())
41
42 def keys(self):
43 """
44 In a series of upgrades x -> y, keys are version x. Sorted.
45 """
46 ret = super(Changeset, self).keys()
47 # Reverse order if downgrading
48 ret.sort(reverse=(self.step < 1))
49 return ret
50
51 def values(self):
52 return [self[k] for k in self.keys()]
53
54 def items(self):
55 return zip(self.keys(), self.values())
56
57 def add(self, change):
58 """Add new change to changeset"""
59 key = self.end
60 self.end += self.step
61 self[key] = change
62
63 def run(self, *p, **k):
64 """Run the changeset scripts"""
65 for version, script in self:
66 script.run(*p, **k)
67
68
69 class Repository(pathed.Pathed):
70 """A project's change script repository"""
71
72 _config = 'migrate.cfg'
73 _versions = 'versions'
74
75 def __init__(self, path):
76 log.debug('Loading repository %s...' % path)
77 self.verify(path)
78 super(Repository, self).__init__(path)
79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 self.versions = version.Collection(os.path.join(self.path,
81 self._versions))
82 log.debug('Repository %s loaded successfully' % path)
83 log.debug('Config: %r' % self.config.to_dict())
84
85 @classmethod
86 def verify(cls, path):
87 """
88 Ensure the target path is a valid repository.
89
90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 """
92 # Ensure the existence of required files
93 try:
94 cls.require_found(path)
95 cls.require_found(os.path.join(path, cls._config))
96 cls.require_found(os.path.join(path, cls._versions))
97 except exceptions.PathNotFoundError, e:
98 raise exceptions.InvalidRepositoryError(path)
99
100 @classmethod
101 def prepare_config(cls, tmpl_dir, name, options=None):
102 """
103 Prepare a project configuration file for a new project.
104
105 :param tmpl_dir: Path to Repository template
106 :param config_file: Name of the config file in Repository template
107 :param name: Repository name
108 :type tmpl_dir: string
109 :type config_file: string
110 :type name: string
111 :returns: Populated config file
112 """
113 if options is None:
114 options = {}
115 options.setdefault('version_table', 'migrate_version')
116 options.setdefault('repository_id', name)
117 options.setdefault('required_dbs', [])
118
119 tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
120 ret = TempitaTemplate(tmpl).substitute(options)
121
122 # cleanup
123 del options['__template_name__']
124
125 return ret
126
127 @classmethod
128 def create(cls, path, name, **opts):
129 """Create a repository at a specified path"""
130 cls.require_notfound(path)
131 theme = opts.pop('templates_theme', None)
132 t_path = opts.pop('templates_path', None)
133
134 # Create repository
135 tmpl_dir = Template(t_path).get_repository(theme=theme)
136 shutil.copytree(tmpl_dir, path)
137
138 # Edit config defaults
139 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
140 fd = open(os.path.join(path, cls._config), 'w')
141 fd.write(config_text)
142 fd.close()
143
144 opts['repository_name'] = name
145
146 # Create a management script
147 manager = os.path.join(path, 'manage.py')
148 Repository.create_manage_file(manager, templates_theme=theme,
149 templates_path=t_path, **opts)
150
151 return cls(path)
152
153 def create_script(self, description, **k):
154 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
155 self.versions.create_new_python_version(description, **k)
156
157 def create_script_sql(self, database, **k):
158 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
159 self.versions.create_new_sql_version(database, **k)
160
161 @property
162 def latest(self):
163 """API to :attr:`migrate.versioning.version.Collection.latest`"""
164 return self.versions.latest
165
166 @property
167 def version_table(self):
168 """Returns version_table name specified in config"""
169 return self.config.get('db_settings', 'version_table')
170
171 @property
172 def id(self):
173 """Returns repository id specified in config"""
174 return self.config.get('db_settings', 'repository_id')
175
176 def version(self, *p, **k):
177 """API to :attr:`migrate.versioning.version.Collection.version`"""
178 return self.versions.version(*p, **k)
179
180 @classmethod
181 def clear(cls):
182 # TODO: deletes repo
183 super(Repository, cls).clear()
184 version.Collection.clear()
185
186 def changeset(self, database, start, end=None):
187 """Create a changeset to migrate this database from ver. start to end/latest.
188
189 :param database: name of database to generate changeset
190 :param start: version to start at
191 :param end: version to end at (latest if None given)
192 :type database: string
193 :type start: int
194 :type end: int
195 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
196 """
197 start = version.VerNum(start)
198
199 if end is None:
200 end = self.latest
201 else:
202 end = version.VerNum(end)
203
204 if start <= end:
205 step = 1
206 range_mod = 1
207 op = 'upgrade'
208 else:
209 step = -1
210 range_mod = 0
211 op = 'downgrade'
212
213 versions = range(start + range_mod, end + range_mod, step)
214 changes = [self.version(v).script(database, op) for v in versions]
215 ret = Changeset(start, step=step, *changes)
216 return ret
217
218 @classmethod
219 def create_manage_file(cls, file_, **opts):
220 """Create a project management script (manage.py)
221
222 :param file_: Destination file to be written
223 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
224 """
225 mng_file = Template(opts.pop('templates_path', None))\
226 .get_manage(theme=opts.pop('templates_theme', None))
227
228 tmpl = open(mng_file).read()
229 fd = open(file_, 'w')
230 fd.write(TempitaTemplate(tmpl).substitute(opts))
231 fd.close()
@@ -0,0 +1,213 b''
1 """
2 Database schema version management.
3 """
4 import sys
5 import logging
6
7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 create_engine)
9 from sqlalchemy.sql import and_
10 from sqlalchemy import exceptions as sa_exceptions
11 from sqlalchemy.sql import bindparam
12
13 from migrate import exceptions
14 from migrate.versioning import genmodel, schemadiff
15 from migrate.versioning.repository import Repository
16 from migrate.versioning.util import load_model
17 from migrate.versioning.version import VerNum
18
19
20 log = logging.getLogger(__name__)
21
22 class ControlledSchema(object):
23 """A database under version control"""
24
25 def __init__(self, engine, repository):
26 if isinstance(repository, basestring):
27 repository = Repository(repository)
28 self.engine = engine
29 self.repository = repository
30 self.meta = MetaData(engine)
31 self.load()
32
33 def __eq__(self, other):
34 """Compare two schemas by repositories and versions"""
35 return (self.repository is other.repository \
36 and self.version == other.version)
37
38 def load(self):
39 """Load controlled schema version info from DB"""
40 tname = self.repository.version_table
41 try:
42 if not hasattr(self, 'table') or self.table is None:
43 self.table = Table(tname, self.meta, autoload=True)
44
45 result = self.engine.execute(self.table.select(
46 self.table.c.repository_id == str(self.repository.id)))
47
48 data = list(result)[0]
49 except:
50 cls, exc, tb = sys.exc_info()
51 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
52
53 self.version = data['version']
54 return data
55
56 def drop(self):
57 """
58 Remove version control from a database.
59 """
60 try:
61 self.table.drop()
62 except (sa_exceptions.SQLError):
63 raise exceptions.DatabaseNotControlledError(str(self.table))
64
65 def changeset(self, version=None):
66 """API to Changeset creation.
67
68 Uses self.version for start version and engine.name
69 to get database name.
70 """
71 database = self.engine.name
72 start_ver = self.version
73 changeset = self.repository.changeset(database, start_ver, version)
74 return changeset
75
76 def runchange(self, ver, change, step):
77 startver = ver
78 endver = ver + step
79 # Current database version must be correct! Don't run if corrupt!
80 if self.version != startver:
81 raise exceptions.InvalidVersionError("%s is not %s" % \
82 (self.version, startver))
83 # Run the change
84 change.run(self.engine, step)
85
86 # Update/refresh database version
87 self.update_repository_table(startver, endver)
88 self.load()
89
90 def update_repository_table(self, startver, endver):
91 """Update version_table with new information"""
92 update = self.table.update(and_(self.table.c.version == int(startver),
93 self.table.c.repository_id == str(self.repository.id)))
94 self.engine.execute(update, version=int(endver))
95
96 def upgrade(self, version=None):
97 """
98 Upgrade (or downgrade) to a specified version, or latest version.
99 """
100 changeset = self.changeset(version)
101 for ver, change in changeset:
102 self.runchange(ver, change, changeset.step)
103
104 def update_db_from_model(self, model):
105 """
106 Modify the database to match the structure of the current Python model.
107 """
108 model = load_model(model)
109
110 diff = schemadiff.getDiffOfModelAgainstDatabase(
111 model, self.engine, excludeTables=[self.repository.version_table]
112 )
113 genmodel.ModelGenerator(diff,self.engine).applyModel()
114
115 self.update_repository_table(self.version, int(self.repository.latest))
116
117 self.load()
118
119 @classmethod
120 def create(cls, engine, repository, version=None):
121 """
122 Declare a database to be under a repository's version control.
123
124 :raises: :exc:`DatabaseAlreadyControlledError`
125 :returns: :class:`ControlledSchema`
126 """
127 # Confirm that the version # is valid: positive, integer,
128 # exists in repos
129 if isinstance(repository, basestring):
130 repository = Repository(repository)
131 version = cls._validate_version(repository, version)
132 table = cls._create_table_version(engine, repository, version)
133 # TODO: history table
134 # Load repository information and return
135 return cls(engine, repository)
136
137 @classmethod
138 def _validate_version(cls, repository, version):
139 """
140 Ensures this is a valid version number for this repository.
141
142 :raises: :exc:`InvalidVersionError` if invalid
143 :return: valid version number
144 """
145 if version is None:
146 version = 0
147 try:
148 version = VerNum(version) # raises valueerror
149 if version < 0 or version > repository.latest:
150 raise ValueError()
151 except ValueError:
152 raise exceptions.InvalidVersionError(version)
153 return version
154
155 @classmethod
156 def _create_table_version(cls, engine, repository, version):
157 """
158 Creates the versioning table in a database.
159
160 :raises: :exc:`DatabaseAlreadyControlledError`
161 """
162 # Create tables
163 tname = repository.version_table
164 meta = MetaData(engine)
165
166 table = Table(
167 tname, meta,
168 Column('repository_id', String(250), primary_key=True),
169 Column('repository_path', Text),
170 Column('version', Integer), )
171
172 # there can be multiple repositories/schemas in the same db
173 if not table.exists():
174 table.create()
175
176 # test for existing repository_id
177 s = table.select(table.c.repository_id == bindparam("repository_id"))
178 result = engine.execute(s, repository_id=repository.id)
179 if result.fetchone():
180 raise exceptions.DatabaseAlreadyControlledError
181
182 # Insert data
183 engine.execute(table.insert().values(
184 repository_id=repository.id,
185 repository_path=repository.path,
186 version=int(version)))
187 return table
188
189 @classmethod
190 def compare_model_to_db(cls, engine, model, repository):
191 """
192 Compare the current model against the current database.
193 """
194 if isinstance(repository, basestring):
195 repository = Repository(repository)
196 model = load_model(model)
197
198 diff = schemadiff.getDiffOfModelAgainstDatabase(
199 model, engine, excludeTables=[repository.version_table])
200 return diff
201
202 @classmethod
203 def create_model(cls, engine, repository, declarative=False):
204 """
205 Dump the current database as a Python model.
206 """
207 if isinstance(repository, basestring):
208 repository = Repository(repository)
209
210 diff = schemadiff.getDiffOfModelAgainstDatabase(
211 MetaData(), engine, excludeTables=[repository.version_table]
212 )
213 return genmodel.ModelGenerator(diff, engine, declarative).toPython()
@@ -0,0 +1,285 b''
1 """
2 Schema differencing support.
3 """
4
5 import logging
6 import sqlalchemy
7
8 from migrate.changeset import SQLA_06
9 from sqlalchemy.types import Float
10
11 log = logging.getLogger(__name__)
12
13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
14 """
15 Return differences of model against database.
16
17 :return: object which will evaluate to :keyword:`True` if there \
18 are differences else :keyword:`False`.
19 """
20 return SchemaDiff(metadata,
21 sqlalchemy.MetaData(engine, reflect=True),
22 labelA='model',
23 labelB='database',
24 excludeTables=excludeTables)
25
26
27 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
28 """
29 Return differences of model against another model.
30
31 :return: object which will evaluate to :keyword:`True` if there \
32 are differences else :keyword:`False`.
33 """
34 return SchemaDiff(metadataA, metadataB, excludeTables)
35
36
37 class ColDiff(object):
38 """
39 Container for differences in one :class:`~sqlalchemy.schema.Column`
40 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
41 and ``B``.
42
43 .. attribute:: col_A
44
45 The :class:`~sqlalchemy.schema.Column` object for A.
46
47 .. attribute:: col_B
48
49 The :class:`~sqlalchemy.schema.Column` object for B.
50
51 .. attribute:: type_A
52
53 The most generic type of the :class:`~sqlalchemy.schema.Column`
54 object in A.
55
56 .. attribute:: type_B
57
58 The most generic type of the :class:`~sqlalchemy.schema.Column`
59 object in A.
60
61 """
62
63 diff = False
64
65 def __init__(self,col_A,col_B):
66 self.col_A = col_A
67 self.col_B = col_B
68
69 self.type_A = col_A.type
70 self.type_B = col_B.type
71
72 self.affinity_A = self.type_A._type_affinity
73 self.affinity_B = self.type_B._type_affinity
74
75 if self.affinity_A is not self.affinity_B:
76 self.diff = True
77 return
78
79 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
80 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
81 self.diff=True
82 return
83
84 for attr in ('precision','scale','length'):
85 A = getattr(self.type_A,attr,None)
86 B = getattr(self.type_B,attr,None)
87 if not (A is None or B is None) and A!=B:
88 self.diff=True
89 return
90
91 def __nonzero__(self):
92 return self.diff
93
94 class TableDiff(object):
95 """
96 Container for differences in one :class:`~sqlalchemy.schema.Table`
97 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
98 and ``B``.
99
100 .. attribute:: columns_missing_from_A
101
102 A sequence of column names that were found in B but weren't in
103 A.
104
105 .. attribute:: columns_missing_from_B
106
107 A sequence of column names that were found in A but weren't in
108 B.
109
110 .. attribute:: columns_different
111
112 A dictionary containing information about columns that were
113 found to be different.
114 It maps column names to a :class:`ColDiff` objects describing the
115 differences found.
116 """
117 __slots__ = (
118 'columns_missing_from_A',
119 'columns_missing_from_B',
120 'columns_different',
121 )
122
123 def __nonzero__(self):
124 return bool(
125 self.columns_missing_from_A or
126 self.columns_missing_from_B or
127 self.columns_different
128 )
129
130 class SchemaDiff(object):
131 """
132 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
133 objects.
134
135 The string representation of a :class:`SchemaDiff` will summarise
136 the changes found between the two
137 :class:`~sqlalchemy.schema.MetaData` objects.
138
139 The length of a :class:`SchemaDiff` will give the number of
140 changes found, enabling it to be used much like a boolean in
141 expressions.
142
143 :param metadataA:
144 First :class:`~sqlalchemy.schema.MetaData` to compare.
145
146 :param metadataB:
147 Second :class:`~sqlalchemy.schema.MetaData` to compare.
148
149 :param labelA:
150 The label to use in messages about the first
151 :class:`~sqlalchemy.schema.MetaData`.
152
153 :param labelB:
154 The label to use in messages about the second
155 :class:`~sqlalchemy.schema.MetaData`.
156
157 :param excludeTables:
158 A sequence of table names to exclude.
159
160 .. attribute:: tables_missing_from_A
161
162 A sequence of table names that were found in B but weren't in
163 A.
164
165 .. attribute:: tables_missing_from_B
166
167 A sequence of table names that were found in A but weren't in
168 B.
169
170 .. attribute:: tables_different
171
172 A dictionary containing information about tables that were found
173 to be different.
174 It maps table names to a :class:`TableDiff` objects describing the
175 differences found.
176 """
177
178 def __init__(self,
179 metadataA, metadataB,
180 labelA='metadataA',
181 labelB='metadataB',
182 excludeTables=None):
183
184 self.metadataA, self.metadataB = metadataA, metadataB
185 self.labelA, self.labelB = labelA, labelB
186 self.label_width = max(len(labelA),len(labelB))
187 excludeTables = set(excludeTables or [])
188
189 A_table_names = set(metadataA.tables.keys())
190 B_table_names = set(metadataB.tables.keys())
191
192 self.tables_missing_from_A = sorted(
193 B_table_names - A_table_names - excludeTables
194 )
195 self.tables_missing_from_B = sorted(
196 A_table_names - B_table_names - excludeTables
197 )
198
199 self.tables_different = {}
200 for table_name in A_table_names.intersection(B_table_names):
201
202 td = TableDiff()
203
204 A_table = metadataA.tables[table_name]
205 B_table = metadataB.tables[table_name]
206
207 A_column_names = set(A_table.columns.keys())
208 B_column_names = set(B_table.columns.keys())
209
210 td.columns_missing_from_A = sorted(
211 B_column_names - A_column_names
212 )
213
214 td.columns_missing_from_B = sorted(
215 A_column_names - B_column_names
216 )
217
218 td.columns_different = {}
219
220 for col_name in A_column_names.intersection(B_column_names):
221
222 cd = ColDiff(
223 A_table.columns.get(col_name),
224 B_table.columns.get(col_name)
225 )
226
227 if cd:
228 td.columns_different[col_name]=cd
229
230 # XXX - index and constraint differences should
231 # be checked for here
232
233 if td:
234 self.tables_different[table_name]=td
235
236 def __str__(self):
237 ''' Summarize differences. '''
238 out = []
239 column_template =' %%%is: %%r' % self.label_width
240
241 for names,label in (
242 (self.tables_missing_from_A,self.labelA),
243 (self.tables_missing_from_B,self.labelB),
244 ):
245 if names:
246 out.append(
247 ' tables missing from %s: %s' % (
248 label,', '.join(sorted(names))
249 )
250 )
251
252 for name,td in sorted(self.tables_different.items()):
253 out.append(
254 ' table with differences: %s' % name
255 )
256 for names,label in (
257 (td.columns_missing_from_A,self.labelA),
258 (td.columns_missing_from_B,self.labelB),
259 ):
260 if names:
261 out.append(
262 ' %s missing these columns: %s' % (
263 label,', '.join(sorted(names))
264 )
265 )
266 for name,cd in td.columns_different.items():
267 out.append(' column with differences: %s' % name)
268 out.append(column_template % (self.labelA,cd.col_A))
269 out.append(column_template % (self.labelB,cd.col_B))
270
271 if out:
272 out.insert(0, 'Schema diffs:')
273 return '\n'.join(out)
274 else:
275 return 'No schema diffs'
276
277 def __len__(self):
278 """
279 Used in bool evaluation, return of 0 means no diffs.
280 """
281 return (
282 len(self.tables_missing_from_A) +
283 len(self.tables_missing_from_B) +
284 len(self.tables_different)
285 )
@@ -0,0 +1,6 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from migrate.versioning.script.base import BaseScript
5 from migrate.versioning.script.py import PythonScript
6 from migrate.versioning.script.sql import SqlScript
@@ -0,0 +1,57 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 import logging
4
5 from migrate import exceptions
6 from migrate.versioning.config import operations
7 from migrate.versioning import pathed
8
9
10 log = logging.getLogger(__name__)
11
12 class BaseScript(pathed.Pathed):
13 """Base class for other types of scripts.
14 All scripts have the following properties:
15
16 source (script.source())
17 The source code of the script
18 version (script.version())
19 The version number of the script
20 operations (script.operations())
21 The operations defined by the script: upgrade(), downgrade() or both.
22 Returns a tuple of operations.
23 Can also check for an operation with ex. script.operation(Script.ops.up)
24 """ # TODO: sphinxfy this and implement it correctly
25
26 def __init__(self, path):
27 log.debug('Loading script %s...' % path)
28 self.verify(path)
29 super(BaseScript, self).__init__(path)
30 log.debug('Script %s loaded successfully' % path)
31
32 @classmethod
33 def verify(cls, path):
34 """Ensure this is a valid script
35 This version simply ensures the script file's existence
36
37 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
38 """
39 try:
40 cls.require_found(path)
41 except:
42 raise exceptions.InvalidScriptError(path)
43
44 def source(self):
45 """:returns: source code of the script.
46 :rtype: string
47 """
48 fd = open(self.path)
49 ret = fd.read()
50 fd.close()
51 return ret
52
53 def run(self, engine):
54 """Core of each BaseScript subclass.
55 This method executes the script.
56 """
57 raise NotImplementedError()
@@ -0,0 +1,159 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import shutil
5 import warnings
6 import logging
7 from StringIO import StringIO
8
9 import migrate
10 from migrate.versioning import genmodel, schemadiff
11 from migrate.versioning.config import operations
12 from migrate.versioning.template import Template
13 from migrate.versioning.script import base
14 from migrate.versioning.util import import_path, load_model, with_engine
15 from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
16
17 log = logging.getLogger(__name__)
18 __all__ = ['PythonScript']
19
20
21 class PythonScript(base.BaseScript):
22 """Base for Python scripts"""
23
24 @classmethod
25 def create(cls, path, **opts):
26 """Create an empty migration script at specified path
27
28 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
29 cls.require_notfound(path)
30
31 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
32 shutil.copy(src, path)
33
34 return cls(path)
35
36 @classmethod
37 def make_update_script_for_model(cls, engine, oldmodel,
38 model, repository, **opts):
39 """Create a migration script based on difference between two SA models.
40
41 :param repository: path to migrate repository
42 :param oldmodel: dotted.module.name:SAClass or SAClass object
43 :param model: dotted.module.name:SAClass or SAClass object
44 :param engine: SQLAlchemy engine
45 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
46 :type oldmodel: string or Class
47 :type model: string or Class
48 :type engine: Engine instance
49 :returns: Upgrade / Downgrade script
50 :rtype: string
51 """
52
53 if isinstance(repository, basestring):
54 # oh dear, an import cycle!
55 from migrate.versioning.repository import Repository
56 repository = Repository(repository)
57
58 oldmodel = load_model(oldmodel)
59 model = load_model(model)
60
61 # Compute differences.
62 diff = schemadiff.getDiffOfModelAgainstModel(
63 oldmodel,
64 model,
65 excludeTables=[repository.version_table])
66 # TODO: diff can be False (there is no difference?)
67 decls, upgradeCommands, downgradeCommands = \
68 genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
69
70 # Store differences into file.
71 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
72 f = open(src)
73 contents = f.read()
74 f.close()
75
76 # generate source
77 search = 'def upgrade(migrate_engine):'
78 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
79 if upgradeCommands:
80 contents = contents.replace(' pass', upgradeCommands, 1)
81 if downgradeCommands:
82 contents = contents.replace(' pass', downgradeCommands, 1)
83 return contents
84
85 @classmethod
86 def verify_module(cls, path):
87 """Ensure path is a valid script
88
89 :param path: Script location
90 :type path: string
91 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
92 :returns: Python module
93 """
94 # Try to import and get the upgrade() func
95 module = import_path(path)
96 try:
97 assert callable(module.upgrade)
98 except Exception, e:
99 raise InvalidScriptError(path + ': %s' % str(e))
100 return module
101
102 def preview_sql(self, url, step, **args):
103 """Mocks SQLAlchemy Engine to store all executed calls in a string
104 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
105
106 :returns: SQL file
107 """
108 buf = StringIO()
109 args['engine_arg_strategy'] = 'mock'
110 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
111
112 @with_engine
113 def go(url, step, **kw):
114 engine = kw.pop('engine')
115 self.run(engine, step)
116 return buf.getvalue()
117
118 return go(url, step, **args)
119
120 def run(self, engine, step):
121 """Core method of Script file.
122 Exectues :func:`update` or :func:`downgrade` functions
123
124 :param engine: SQLAlchemy Engine
125 :param step: Operation to run
126 :type engine: string
127 :type step: int
128 """
129 if step > 0:
130 op = 'upgrade'
131 elif step < 0:
132 op = 'downgrade'
133 else:
134 raise ScriptError("%d is not a valid step" % step)
135
136 funcname = base.operations[op]
137 script_func = self._func(funcname)
138
139 try:
140 script_func(engine)
141 except TypeError:
142 warnings.warn("upgrade/downgrade functions must accept engine"
143 " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
144 raise
145
146 @property
147 def module(self):
148 """Calls :meth:`migrate.versioning.script.py.verify_module`
149 and returns it.
150 """
151 if not hasattr(self, '_module'):
152 self._module = self.verify_module(self.path)
153 return self._module
154
155 def _func(self, funcname):
156 if not hasattr(self.module, funcname):
157 msg = "Function '%s' is not defined in this script"
158 raise ScriptError(msg % funcname)
159 return getattr(self.module, funcname)
@@ -0,0 +1,49 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 import logging
4 import shutil
5
6 from migrate.versioning.script import base
7 from migrate.versioning.template import Template
8
9
10 log = logging.getLogger(__name__)
11
12 class SqlScript(base.BaseScript):
13 """A file containing plain SQL statements."""
14
15 @classmethod
16 def create(cls, path, **opts):
17 """Create an empty migration script at specified path
18
19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
20 cls.require_notfound(path)
21
22 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
23 shutil.copy(src, path)
24 return cls(path)
25
26 # TODO: why is step parameter even here?
27 def run(self, engine, step=None, executemany=True):
28 """Runs SQL script through raw dbapi execute call"""
29 text = self.source()
30 # Don't rely on SA's autocommit here
31 # (SA uses .startswith to check if a commit is needed. What if script
32 # starts with a comment?)
33 conn = engine.connect()
34 try:
35 trans = conn.begin()
36 try:
37 # HACK: SQLite doesn't allow multiple statements through
38 # its execute() method, but it provides executescript() instead
39 dbapi = conn.engine.raw_connection()
40 if executemany and getattr(dbapi, 'executescript', None):
41 dbapi.executescript(text)
42 else:
43 conn.execute(text)
44 trans.commit()
45 except:
46 trans.rollback()
47 raise
48 finally:
49 conn.close()
@@ -0,0 +1,215 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 """The migrate command-line tool."""
5
6 import sys
7 import inspect
8 import logging
9 from optparse import OptionParser, BadOptionError
10
11 from migrate import exceptions
12 from migrate.versioning import api
13 from migrate.versioning.config import *
14 from migrate.versioning.util import asbool
15
16
17 alias = dict(
18 s=api.script,
19 vc=api.version_control,
20 dbv=api.db_version,
21 v=api.version,
22 )
23
24 def alias_setup():
25 global alias
26 for key, val in alias.iteritems():
27 setattr(api, key, val)
28 alias_setup()
29
30
31 class PassiveOptionParser(OptionParser):
32
33 def _process_args(self, largs, rargs, values):
34 """little hack to support all --some_option=value parameters"""
35
36 while rargs:
37 arg = rargs[0]
38 if arg == "--":
39 del rargs[0]
40 return
41 elif arg[0:2] == "--":
42 # if parser does not know about the option
43 # pass it along (make it anonymous)
44 try:
45 opt = arg.split('=', 1)[0]
46 self._match_long_opt(opt)
47 except BadOptionError:
48 largs.append(arg)
49 del rargs[0]
50 else:
51 self._process_long_opt(rargs, values)
52 elif arg[:1] == "-" and len(arg) > 1:
53 self._process_short_opts(rargs, values)
54 elif self.allow_interspersed_args:
55 largs.append(arg)
56 del rargs[0]
57
58 def main(argv=None, **kwargs):
59 """Shell interface to :mod:`migrate.versioning.api`.
60
61 kwargs are default options that can be overriden with passing
62 --some_option as command line option
63
64 :param disable_logging: Let migrate configure logging
65 :type disable_logging: bool
66 """
67 if argv is not None:
68 argv = argv
69 else:
70 argv = list(sys.argv[1:])
71 commands = list(api.__all__)
72 commands.sort()
73
74 usage = """%%prog COMMAND ...
75
76 Available commands:
77 %s
78
79 Enter "%%prog help COMMAND" for information on a particular command.
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
81 api.command_desc.get(command)) for command in commands])
82
83 parser = PassiveOptionParser(usage=usage)
84 parser.add_option("-d", "--debug",
85 action="store_true",
86 dest="debug",
87 default=False,
88 help="Shortcut to turn on DEBUG mode for logging")
89 parser.add_option("-q", "--disable_logging",
90 action="store_true",
91 dest="disable_logging",
92 default=False,
93 help="Use this option to disable logging configuration")
94 help_commands = ['help', '-h', '--help']
95 HELP = False
96
97 try:
98 command = argv.pop(0)
99 if command in help_commands:
100 HELP = True
101 command = argv.pop(0)
102 except IndexError:
103 parser.print_help()
104 return
105
106 command_func = getattr(api, command, None)
107 if command_func is None or command.startswith('_'):
108 parser.error("Invalid command %s" % command)
109
110 parser.set_usage(inspect.getdoc(command_func))
111 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
112 for arg in f_args:
113 parser.add_option(
114 "--%s" % arg,
115 dest=arg,
116 action='store',
117 type="string")
118
119 # display help of the current command
120 if HELP:
121 parser.print_help()
122 return
123
124 options, args = parser.parse_args(argv)
125
126 # override kwargs with anonymous parameters
127 override_kwargs = dict()
128 for arg in list(args):
129 if arg.startswith('--'):
130 args.remove(arg)
131 if '=' in arg:
132 opt, value = arg[2:].split('=', 1)
133 else:
134 opt = arg[2:]
135 value = True
136 override_kwargs[opt] = value
137
138 # override kwargs with options if user is overwriting
139 for key, value in options.__dict__.iteritems():
140 if value is not None:
141 override_kwargs[key] = value
142
143 # arguments that function accepts without passed kwargs
144 f_required = list(f_args)
145 candidates = dict(kwargs)
146 candidates.update(override_kwargs)
147 for key, value in candidates.iteritems():
148 if key in f_args:
149 f_required.remove(key)
150
151 # map function arguments to parsed arguments
152 for arg in args:
153 try:
154 kw = f_required.pop(0)
155 except IndexError:
156 parser.error("Too many arguments for command %s: %s" % (command,
157 arg))
158 kwargs[kw] = arg
159
160 # apply overrides
161 kwargs.update(override_kwargs)
162
163 # configure options
164 for key, value in options.__dict__.iteritems():
165 kwargs.setdefault(key, value)
166
167 # configure logging
168 if not asbool(kwargs.pop('disable_logging', False)):
169 # filter to log =< INFO into stdout and rest to stderr
170 class SingleLevelFilter(logging.Filter):
171 def __init__(self, min=None, max=None):
172 self.min = min or 0
173 self.max = max or 100
174
175 def filter(self, record):
176 return self.min <= record.levelno <= self.max
177
178 logger = logging.getLogger()
179 h1 = logging.StreamHandler(sys.stdout)
180 f1 = SingleLevelFilter(max=logging.INFO)
181 h1.addFilter(f1)
182 h2 = logging.StreamHandler(sys.stderr)
183 f2 = SingleLevelFilter(min=logging.WARN)
184 h2.addFilter(f2)
185 logger.addHandler(h1)
186 logger.addHandler(h2)
187
188 if options.debug:
189 logger.setLevel(logging.DEBUG)
190 else:
191 logger.setLevel(logging.INFO)
192
193 log = logging.getLogger(__name__)
194
195 # check if all args are given
196 try:
197 num_defaults = len(f_defaults)
198 except TypeError:
199 num_defaults = 0
200 f_args_default = f_args[len(f_args) - num_defaults:]
201 required = list(set(f_required) - set(f_args_default))
202 if required:
203 parser.error("Not enough arguments for command %s: %s not specified" \
204 % (command, ', '.join(required)))
205
206 # handle command
207 try:
208 ret = command_func(**kwargs)
209 if ret is not None:
210 log.info(ret)
211 except (exceptions.UsageError, exceptions.KnownError), e:
212 parser.error(e.args[0])
213
214 if __name__ == "__main__":
215 main()
@@ -0,0 +1,94 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import os
5 import shutil
6 import sys
7
8 from pkg_resources import resource_filename
9
10 from migrate.versioning.config import *
11 from migrate.versioning import pathed
12
13
14 class Collection(pathed.Pathed):
15 """A collection of templates of a specific type"""
16 _mask = None
17
18 def get_path(self, file):
19 return os.path.join(self.path, str(file))
20
21
22 class RepositoryCollection(Collection):
23 _mask = '%s'
24
25 class ScriptCollection(Collection):
26 _mask = '%s.py_tmpl'
27
28 class ManageCollection(Collection):
29 _mask = '%s.py_tmpl'
30
31 class SQLScriptCollection(Collection):
32 _mask = '%s.py_tmpl'
33
34 class Template(pathed.Pathed):
35 """Finds the paths/packages of various Migrate templates.
36
37 :param path: Templates are loaded from migrate package
38 if `path` is not provided.
39 """
40 pkg = 'migrate.versioning.templates'
41 _manage = 'manage.py_tmpl'
42
43 def __new__(cls, path=None):
44 if path is None:
45 path = cls._find_path(cls.pkg)
46 return super(Template, cls).__new__(cls, path)
47
48 def __init__(self, path=None):
49 if path is None:
50 path = Template._find_path(self.pkg)
51 super(Template, self).__init__(path)
52 self.repository = RepositoryCollection(os.path.join(path, 'repository'))
53 self.script = ScriptCollection(os.path.join(path, 'script'))
54 self.manage = ManageCollection(os.path.join(path, 'manage'))
55 self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
56
57 @classmethod
58 def _find_path(cls, pkg):
59 """Returns absolute path to dotted python package."""
60 tmp_pkg = pkg.rsplit('.', 1)
61
62 if len(tmp_pkg) != 1:
63 return resource_filename(tmp_pkg[0], tmp_pkg[1])
64 else:
65 return resource_filename(tmp_pkg[0], '')
66
67 def _get_item(self, collection, theme=None):
68 """Locates and returns collection.
69
70 :param collection: name of collection to locate
71 :param type_: type of subfolder in collection (defaults to "_default")
72 :returns: (package, source)
73 :rtype: str, str
74 """
75 item = getattr(self, collection)
76 theme_mask = getattr(item, '_mask')
77 theme = theme_mask % (theme or 'default')
78 return item.get_path(theme)
79
80 def get_repository(self, *a, **kw):
81 """Calls self._get_item('repository', *a, **kw)"""
82 return self._get_item('repository', *a, **kw)
83
84 def get_script(self, *a, **kw):
85 """Calls self._get_item('script', *a, **kw)"""
86 return self._get_item('script', *a, **kw)
87
88 def get_sql_script(self, *a, **kw):
89 """Calls self._get_item('sql_script', *a, **kw)"""
90 return self._get_item('sql_script', *a, **kw)
91
92 def get_manage(self, *a, **kw):
93 """Calls self._get_item('manage', *a, **kw)"""
94 return self._get_item('manage', *a, **kw)
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,5 b''
1 #!/usr/bin/env python
2 from migrate.versioning.shell import main
3
4 if __name__ == '__main__':
5 main(%(defaults)s)
@@ -0,0 +1,10 b''
1 #!/usr/bin/env python
2 from migrate.versioning.shell import main
3
4 {{py:
5 _vars = locals().copy()
6 del _vars['__template_name__']
7 _vars.pop('repository_name', None)
8 defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
9 }}
10 main({{ defaults }})
@@ -0,0 +1,29 b''
1 #!/usr/bin/python
2 # -*- coding: utf-8 -*-
3 import sys
4
5 from sqlalchemy import engine_from_config
6 from paste.deploy.loadwsgi import ConfigLoader
7
8 from migrate.versioning.shell import main
9 from {{ locals().pop('repository_name') }}.model import migrations
10
11
12 if '-c' in sys.argv:
13 pos = sys.argv.index('-c')
14 conf_path = sys.argv[pos + 1]
15 del sys.argv[pos:pos + 2]
16 else:
17 conf_path = 'development.ini'
18
19 {{py:
20 _vars = locals().copy()
21 del _vars['__template_name__']
22 defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
23 }}
24
25 conf_dict = ConfigLoader(conf_path).parser._sections['app:main']
26
27 # migrate supports passing url as an existing Engine instance (since 0.6.0)
28 # usage: migrate -c path/to/config.ini COMMANDS
29 main(url=engine_from_config(conf_dict), repository=migrations.__path__[0],{{ defaults }})
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,4 b''
1 This is a database migration repository.
2
3 More information at
4 http://code.google.com/p/sqlalchemy-migrate/
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,20 b''
1 [db_settings]
2 # Used to identify which repository this database is versioned under.
3 # You can use the name of your project.
4 repository_id={{ locals().pop('repository_id') }}
5
6 # The name of the database table used to track the schema version.
7 # This name shouldn't already be used by your project.
8 # If this is changed once a database is under version control, you'll need to
9 # change the table name in each database too.
10 version_table={{ locals().pop('version_table') }}
11
12 # When committing a change script, Migrate will attempt to generate the
13 # sql for all supported databases; normally, if one of them fails - probably
14 # because you don't have that database installed - it is ignored and the
15 # commit continues, perhaps ending successfully.
16 # Databases in this list MUST compile successfully during a commit, or the
17 # entire commit will fail. List the databases your application will actually
18 # be using to ensure your updates to that database work properly.
19 # This must be a list; example: ['postgres','sqlite']
20 required_dbs={{ locals().pop('required_dbs') }}
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,4 b''
1 This is a database migration repository.
2
3 More information at
4 http://code.google.com/p/sqlalchemy-migrate/
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,20 b''
1 [db_settings]
2 # Used to identify which repository this database is versioned under.
3 # You can use the name of your project.
4 repository_id={{ locals().pop('repository_id') }}
5
6 # The name of the database table used to track the schema version.
7 # This name shouldn't already be used by your project.
8 # If this is changed once a database is under version control, you'll need to
9 # change the table name in each database too.
10 version_table={{ locals().pop('version_table') }}
11
12 # When committing a change script, Migrate will attempt to generate the
13 # sql for all supported databases; normally, if one of them fails - probably
14 # because you don't have that database installed - it is ignored and the
15 # commit continues, perhaps ending successfully.
16 # Databases in this list MUST compile successfully during a commit, or the
17 # entire commit will fail. List the databases your application will actually
18 # be using to ensure your updates to that database work properly.
19 # This must be a list; example: ['postgres','sqlite']
20 required_dbs={{ locals().pop('required_dbs') }}
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,11 b''
1 from sqlalchemy import *
2 from migrate import *
3
4 def upgrade(migrate_engine):
5 # Upgrade operations go here. Don't create your own engine; bind migrate_engine
6 # to your metadata
7 pass
8
9 def downgrade(migrate_engine):
10 # Operations to reverse the above upgrade go here.
11 pass
@@ -0,0 +1,11 b''
1 from sqlalchemy import *
2 from migrate import *
3
4 def upgrade(migrate_engine):
5 # Upgrade operations go here. Don't create your own engine; bind migrate_engine
6 # to your metadata
7 pass
8
9 def downgrade(migrate_engine):
10 # Operations to reverse the above upgrade go here.
11 pass
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
1 NO CONTENT: new file 100644
NO CONTENT: new file 100644
@@ -0,0 +1,179 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 """.. currentmodule:: migrate.versioning.util"""
4
5 import warnings
6 import logging
7 from decorator import decorator
8 from pkg_resources import EntryPoint
9
10 from sqlalchemy import create_engine
11 from sqlalchemy.engine import Engine
12 from sqlalchemy.pool import StaticPool
13
14 from migrate import exceptions
15 from migrate.versioning.util.keyedinstance import KeyedInstance
16 from migrate.versioning.util.importpath import import_path
17
18
19 log = logging.getLogger(__name__)
20
21 def load_model(dotted_name):
22 """Import module and use module-level variable".
23
24 :param dotted_name: path to model in form of string: ``some.python.module:Class``
25
26 .. versionchanged:: 0.5.4
27
28 """
29 if isinstance(dotted_name, basestring):
30 if ':' not in dotted_name:
31 # backwards compatibility
32 warnings.warn('model should be in form of module.model:User '
33 'and not module.model.User', exceptions.MigrateDeprecationWarning)
34 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
35 return EntryPoint.parse('x=%s' % dotted_name).load(False)
36 else:
37 # Assume it's already loaded.
38 return dotted_name
39
40 def asbool(obj):
41 """Do everything to use object as bool"""
42 if isinstance(obj, basestring):
43 obj = obj.strip().lower()
44 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
45 return True
46 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
47 return False
48 else:
49 raise ValueError("String is not true/false: %r" % obj)
50 if obj in (True, False):
51 return bool(obj)
52 else:
53 raise ValueError("String is not true/false: %r" % obj)
54
55 def guess_obj_type(obj):
56 """Do everything to guess object type from string
57
58 Tries to convert to `int`, `bool` and finally returns if not succeded.
59
60 .. versionadded: 0.5.4
61 """
62
63 result = None
64
65 try:
66 result = int(obj)
67 except:
68 pass
69
70 if result is None:
71 try:
72 result = asbool(obj)
73 except:
74 pass
75
76 if result is not None:
77 return result
78 else:
79 return obj
80
81 @decorator
82 def catch_known_errors(f, *a, **kw):
83 """Decorator that catches known api errors
84
85 .. versionadded: 0.5.4
86 """
87
88 try:
89 return f(*a, **kw)
90 except exceptions.PathFoundError, e:
91 raise exceptions.KnownError("The path %s already exists" % e.args[0])
92
93 def construct_engine(engine, **opts):
94 """.. versionadded:: 0.5.4
95
96 Constructs and returns SQLAlchemy engine.
97
98 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
99
100 :param engine: connection string or a existing engine
101 :param engine_dict: python dictionary of options to pass to `create_engine`
102 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
103 :type engine_dict: dict
104 :type engine: string or Engine instance
105 :type engine_arg_*: string
106 :returns: SQLAlchemy Engine
107
108 .. note::
109
110 keyword parameters override ``engine_dict`` values.
111
112 """
113 if isinstance(engine, Engine):
114 return engine
115 elif not isinstance(engine, basestring):
116 raise ValueError("you need to pass either an existing engine or a database uri")
117
118 # get options for create_engine
119 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
120 kwargs = opts['engine_dict']
121 else:
122 kwargs = dict()
123
124 # DEPRECATED: handle echo the old way
125 echo = asbool(opts.get('echo', False))
126 if echo:
127 warnings.warn('echo=True parameter is deprecated, pass '
128 'engine_arg_echo=True or engine_dict={"echo": True}',
129 exceptions.MigrateDeprecationWarning)
130 kwargs['echo'] = echo
131
132 # parse keyword arguments
133 for key, value in opts.iteritems():
134 if key.startswith('engine_arg_'):
135 kwargs[key[11:]] = guess_obj_type(value)
136
137 log.debug('Constructing engine')
138 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
139 # seems like 0.5.x branch does not work with engine.dispose and staticpool
140 return create_engine(engine, **kwargs)
141
142 @decorator
143 def with_engine(f, *a, **kw):
144 """Decorator for :mod:`migrate.versioning.api` functions
145 to safely close resources after function usage.
146
147 Passes engine parameters to :func:`construct_engine` and
148 resulting parameter is available as kw['engine'].
149
150 Engine is disposed after wrapped function is executed.
151
152 .. versionadded: 0.6.0
153 """
154 url = a[0]
155 engine = construct_engine(url, **kw)
156
157 try:
158 kw['engine'] = engine
159 return f(*a, **kw)
160 finally:
161 if isinstance(engine, Engine):
162 log.debug('Disposing SQLAlchemy engine %s', engine)
163 engine.dispose()
164
165
166 class Memoize:
167 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
168 Will only work on functions with non-mutable arguments
169
170 ActiveState Code 52201
171 """
172 def __init__(self, fn):
173 self.fn = fn
174 self.memo = {}
175
176 def __call__(self, *args):
177 if not self.memo.has_key(args):
178 self.memo[args] = self.fn(*args)
179 return self.memo[args]
@@ -0,0 +1,16 b''
1 import os
2 import sys
3
4 def import_path(fullpath):
5 """ Import a file with full path specification. Allows one to
6 import from anywhere, something __import__ does not do.
7 """
8 # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
9 path, filename = os.path.split(fullpath)
10 filename, ext = os.path.splitext(filename)
11 sys.path.append(path)
12 module = __import__(filename)
13 reload(module) # Might be out of date during tests
14 del sys.path[-1]
15 return module
16
@@ -0,0 +1,36 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 class KeyedInstance(object):
5 """A class whose instances have a unique identifier of some sort
6 No two instances with the same unique ID should exist - if we try to create
7 a second instance, the first should be returned.
8 """
9
10 _instances = dict()
11
12 def __new__(cls, *p, **k):
13 instances = cls._instances
14 clskey = str(cls)
15 if clskey not in instances:
16 instances[clskey] = dict()
17 instances = instances[clskey]
18
19 key = cls._key(*p, **k)
20 if key not in instances:
21 instances[key] = super(KeyedInstance, cls).__new__(cls)
22 return instances[key]
23
24 @classmethod
25 def _key(cls, *p, **k):
26 """Given a unique identifier, return a dictionary key
27 This should be overridden by child classes, to specify which parameters
28 should determine an object's uniqueness
29 """
30 raise NotImplementedError()
31
32 @classmethod
33 def clear(cls):
34 # Allow cls.clear() as well as uniqueInstance.clear(cls)
35 if str(cls) in cls._instances:
36 del cls._instances[str(cls)]
@@ -0,0 +1,215 b''
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import os
5 import re
6 import shutil
7 import logging
8
9 from migrate import exceptions
10 from migrate.versioning import pathed, script
11
12
13 log = logging.getLogger(__name__)
14
15 class VerNum(object):
16 """A version number that behaves like a string and int at the same time"""
17
18 _instances = dict()
19
20 def __new__(cls, value):
21 val = str(value)
22 if val not in cls._instances:
23 cls._instances[val] = super(VerNum, cls).__new__(cls)
24 ret = cls._instances[val]
25 return ret
26
27 def __init__(self,value):
28 self.value = str(int(value))
29 if self < 0:
30 raise ValueError("Version number cannot be negative")
31
32 def __add__(self, value):
33 ret = int(self) + int(value)
34 return VerNum(ret)
35
36 def __sub__(self, value):
37 return self + (int(value) * -1)
38
39 def __cmp__(self, value):
40 return int(self) - int(value)
41
42 def __repr__(self):
43 return "<VerNum(%s)>" % self.value
44
45 def __str__(self):
46 return str(self.value)
47
48 def __int__(self):
49 return int(self.value)
50
51
52 class Collection(pathed.Pathed):
53 """A collection of versioning scripts in a repository"""
54
55 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
56
57 def __init__(self, path):
58 """Collect current version scripts in repository
59 and store them in self.versions
60 """
61 super(Collection, self).__init__(path)
62
63 # Create temporary list of files, allowing skipped version numbers.
64 files = os.listdir(path)
65 if '1' in files:
66 # deprecation
67 raise Exception('It looks like you have a repository in the old '
68 'format (with directories for each version). '
69 'Please convert repository before proceeding.')
70
71 tempVersions = dict()
72 for filename in files:
73 match = self.FILENAME_WITH_VERSION.match(filename)
74 if match:
75 num = int(match.group(1))
76 tempVersions.setdefault(num, []).append(filename)
77 else:
78 pass # Must be a helper file or something, let's ignore it.
79
80 # Create the versions member where the keys
81 # are VerNum's and the values are Version's.
82 self.versions = dict()
83 for num, files in tempVersions.items():
84 self.versions[VerNum(num)] = Version(num, path, files)
85
86 @property
87 def latest(self):
88 """:returns: Latest version in Collection"""
89 return max([VerNum(0)] + self.versions.keys())
90
91 def create_new_python_version(self, description, **k):
92 """Create Python files for new version"""
93 ver = self.latest + 1
94 extra = str_to_filename(description)
95
96 if extra:
97 if extra == '_':
98 extra = ''
99 elif not extra.startswith('_'):
100 extra = '_%s' % extra
101
102 filename = '%03d%s.py' % (ver, extra)
103 filepath = self._version_path(filename)
104
105 script.PythonScript.create(filepath, **k)
106 self.versions[ver] = Version(ver, self.path, [filename])
107
108 def create_new_sql_version(self, database, **k):
109 """Create SQL files for new version"""
110 ver = self.latest + 1
111 self.versions[ver] = Version(ver, self.path, [])
112
113 # Create new files.
114 for op in ('upgrade', 'downgrade'):
115 filename = '%03d_%s_%s.sql' % (ver, database, op)
116 filepath = self._version_path(filename)
117 script.SqlScript.create(filepath, **k)
118 self.versions[ver].add_script(filepath)
119
120 def version(self, vernum=None):
121 """Returns latest Version if vernum is not given.
122 Otherwise, returns wanted version"""
123 if vernum is None:
124 vernum = self.latest
125 return self.versions[VerNum(vernum)]
126
127 @classmethod
128 def clear(cls):
129 super(Collection, cls).clear()
130
131 def _version_path(self, ver):
132 """Returns path of file in versions repository"""
133 return os.path.join(self.path, str(ver))
134
135
136 class Version(object):
137 """A single version in a collection
138 :param vernum: Version Number
139 :param path: Path to script files
140 :param filelist: List of scripts
141 :type vernum: int, VerNum
142 :type path: string
143 :type filelist: list
144 """
145
146 def __init__(self, vernum, path, filelist):
147 self.version = VerNum(vernum)
148
149 # Collect scripts in this folder
150 self.sql = dict()
151 self.python = None
152
153 for script in filelist:
154 self.add_script(os.path.join(path, script))
155
156 def script(self, database=None, operation=None):
157 """Returns SQL or Python Script"""
158 for db in (database, 'default'):
159 # Try to return a .sql script first
160 try:
161 return self.sql[db][operation]
162 except KeyError:
163 continue # No .sql script exists
164
165 # TODO: maybe add force Python parameter?
166 ret = self.python
167
168 assert ret is not None, \
169 "There is no script for %d version" % self.version
170 return ret
171
172 def add_script(self, path):
173 """Add script to Collection/Version"""
174 if path.endswith(Extensions.py):
175 self._add_script_py(path)
176 elif path.endswith(Extensions.sql):
177 self._add_script_sql(path)
178
179 SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
180
181 def _add_script_sql(self, path):
182 basename = os.path.basename(path)
183 match = self.SQL_FILENAME.match(basename)
184
185 if match:
186 version, dbms, op = match.group(1), match.group(2), match.group(3)
187 else:
188 raise exceptions.ScriptError(
189 "Invalid SQL script name %s " % basename + \
190 "(needs to be ###_database_operation.sql)")
191
192 # File the script into a dictionary
193 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
194
195 def _add_script_py(self, path):
196 if self.python is not None:
197 raise exceptions.ScriptError('You can only have one Python script '
198 'per version, but you have: %s and %s' % (self.python, path))
199 self.python = script.PythonScript(path)
200
201
202 class Extensions:
203 """A namespace for file extensions"""
204 py = 'py'
205 sql = 'sql'
206
207 def str_to_filename(s):
208 """Replaces spaces, (double and single) quotes
209 and double underscores to underscores
210 """
211
212 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
213 while '__' in s:
214 s = s.replace('__', '_')
215 return s
@@ -0,0 +1,238 b''
1 from migrate import *
2
3 #==============================================================================
4 # DB INITIAL MODEL
5 #==============================================================================
6 import logging
7 import datetime
8
9 from sqlalchemy import *
10 from sqlalchemy.exc import DatabaseError
11 from sqlalchemy.orm import relation, backref, class_mapper
12 from sqlalchemy.orm.session import Session
13
14 from rhodecode.model.meta import Base
15
16 log = logging.getLogger(__name__)
17
18 class BaseModel(object):
19
20 @classmethod
21 def _get_keys(cls):
22 """return column names for this model """
23 return class_mapper(cls).c.keys()
24
25 def get_dict(self):
26 """return dict with keys and values corresponding
27 to this model data """
28
29 d = {}
30 for k in self._get_keys():
31 d[k] = getattr(self, k)
32 return d
33
34 def get_appstruct(self):
35 """return list with keys and values tupples corresponding
36 to this model data """
37
38 l = []
39 for k in self._get_keys():
40 l.append((k, getattr(self, k),))
41 return l
42
43 def populate_obj(self, populate_dict):
44 """populate model with data from given populate_dict"""
45
46 for k in self._get_keys():
47 if k in populate_dict:
48 setattr(self, k, populate_dict[k])
49
50 class RhodeCodeSettings(Base, BaseModel):
51 __tablename__ = 'rhodecode_settings'
52 __table_args__ = (UniqueConstraint('app_settings_name'), {'useexisting':True})
53 app_settings_id = Column("app_settings_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
54 app_settings_name = Column("app_settings_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
55 app_settings_value = Column("app_settings_value", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
56
57 def __init__(self, k, v):
58 self.app_settings_name = k
59 self.app_settings_value = v
60
61 def __repr__(self):
62 return "<RhodeCodeSetting('%s:%s')>" % (self.app_settings_name,
63 self.app_settings_value)
64
65 class RhodeCodeUi(Base, BaseModel):
66 __tablename__ = 'rhodecode_ui'
67 __table_args__ = {'useexisting':True}
68 ui_id = Column("ui_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
69 ui_section = Column("ui_section", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
70 ui_key = Column("ui_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
71 ui_value = Column("ui_value", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
72 ui_active = Column("ui_active", Boolean(), nullable=True, unique=None, default=True)
73
74
75 class User(Base, BaseModel):
76 __tablename__ = 'users'
77 __table_args__ = (UniqueConstraint('username'), UniqueConstraint('email'), {'useexisting':True})
78 user_id = Column("user_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
79 username = Column("username", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
80 password = Column("password", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
81 active = Column("active", Boolean(), nullable=True, unique=None, default=None)
82 admin = Column("admin", Boolean(), nullable=True, unique=None, default=False)
83 name = Column("name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
84 lastname = Column("lastname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
85 email = Column("email", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
86 last_login = Column("last_login", DateTime(timezone=False), nullable=True, unique=None, default=None)
87 is_ldap = Column("is_ldap", Boolean(), nullable=False, unique=None, default=False)
88
89 user_log = relation('UserLog', cascade='all')
90 user_perms = relation('UserToPerm', primaryjoin="User.user_id==UserToPerm.user_id", cascade='all')
91
92 repositories = relation('Repository')
93 user_followers = relation('UserFollowing', primaryjoin='UserFollowing.follows_user_id==User.user_id', cascade='all')
94
95 @property
96 def full_contact(self):
97 return '%s %s <%s>' % (self.name, self.lastname, self.email)
98
99 def __repr__(self):
100 return "<User('id:%s:%s')>" % (self.user_id, self.username)
101
102 def update_lastlogin(self):
103 """Update user lastlogin"""
104
105 try:
106 session = Session.object_session(self)
107 self.last_login = datetime.datetime.now()
108 session.add(self)
109 session.commit()
110 log.debug('updated user %s lastlogin', self.username)
111 except (DatabaseError,):
112 session.rollback()
113
114
115 class UserLog(Base, BaseModel):
116 __tablename__ = 'user_logs'
117 __table_args__ = {'useexisting':True}
118 user_log_id = Column("user_log_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
119 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
120 repository_id = Column("repository_id", Integer(length=None, convert_unicode=False, assert_unicode=None), ForeignKey(u'repositories.repo_id'), nullable=False, unique=None, default=None)
121 repository_name = Column("repository_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
122 user_ip = Column("user_ip", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
123 action = Column("action", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
124 action_date = Column("action_date", DateTime(timezone=False), nullable=True, unique=None, default=None)
125
126 user = relation('User')
127 repository = relation('Repository')
128
129 class Repository(Base, BaseModel):
130 __tablename__ = 'repositories'
131 __table_args__ = (UniqueConstraint('repo_name'), {'useexisting':True},)
132 repo_id = Column("repo_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
133 repo_name = Column("repo_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=False, unique=True, default=None)
134 repo_type = Column("repo_type", String(length=None, convert_unicode=False, assert_unicode=None), nullable=False, unique=False, default=None)
135 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=False, default=None)
136 private = Column("private", Boolean(), nullable=True, unique=None, default=None)
137 enable_statistics = Column("statistics", Boolean(), nullable=True, unique=None, default=True)
138 description = Column("description", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
139 fork_id = Column("fork_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=False, default=None)
140
141 user = relation('User')
142 fork = relation('Repository', remote_side=repo_id)
143 repo_to_perm = relation('RepoToPerm', cascade='all')
144 stats = relation('Statistics', cascade='all', uselist=False)
145
146 repo_followers = relation('UserFollowing', primaryjoin='UserFollowing.follows_repo_id==Repository.repo_id', cascade='all')
147
148
149 def __repr__(self):
150 return "<Repository('%s:%s')>" % (self.repo_id, self.repo_name)
151
152 class Permission(Base, BaseModel):
153 __tablename__ = 'permissions'
154 __table_args__ = {'useexisting':True}
155 permission_id = Column("permission_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
156 permission_name = Column("permission_name", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
157 permission_longname = Column("permission_longname", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
158
159 def __repr__(self):
160 return "<Permission('%s:%s')>" % (self.permission_id, self.permission_name)
161
162 class RepoToPerm(Base, BaseModel):
163 __tablename__ = 'repo_to_perm'
164 __table_args__ = (UniqueConstraint('user_id', 'repository_id'), {'useexisting':True})
165 repo_to_perm_id = Column("repo_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
166 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
167 permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None)
168 repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=None, default=None)
169
170 user = relation('User')
171 permission = relation('Permission')
172 repository = relation('Repository')
173
174 class UserToPerm(Base, BaseModel):
175 __tablename__ = 'user_to_perm'
176 __table_args__ = (UniqueConstraint('user_id', 'permission_id'), {'useexisting':True})
177 user_to_perm_id = Column("user_to_perm_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
178 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
179 permission_id = Column("permission_id", Integer(), ForeignKey(u'permissions.permission_id'), nullable=False, unique=None, default=None)
180
181 user = relation('User')
182 permission = relation('Permission')
183
184 class Statistics(Base, BaseModel):
185 __tablename__ = 'statistics'
186 __table_args__ = (UniqueConstraint('repository_id'), {'useexisting':True})
187 stat_id = Column("stat_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
188 repository_id = Column("repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=False, unique=True, default=None)
189 stat_on_revision = Column("stat_on_revision", Integer(), nullable=False)
190 commit_activity = Column("commit_activity", LargeBinary(), nullable=False)#JSON data
191 commit_activity_combined = Column("commit_activity_combined", LargeBinary(), nullable=False)#JSON data
192 languages = Column("languages", LargeBinary(), nullable=False)#JSON data
193
194 repository = relation('Repository', single_parent=True)
195
196 class UserFollowing(Base, BaseModel):
197 __tablename__ = 'user_followings'
198 __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
199 UniqueConstraint('user_id', 'follows_user_id')
200 , {'useexisting':True})
201
202 user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
203 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
204 follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=None, default=None)
205 follows_user_id = Column("follows_user_id", Integer(), ForeignKey(u'users.user_id'), nullable=True, unique=None, default=None)
206
207 user = relation('User', primaryjoin='User.user_id==UserFollowing.user_id')
208
209 follows_user = relation('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
210 follows_repository = relation('Repository')
211
212
213 class CacheInvalidation(Base, BaseModel):
214 __tablename__ = 'cache_invalidation'
215 __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
216 cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
217 cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
218 cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
219 cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
220
221
222 def __init__(self, cache_key, cache_args=''):
223 self.cache_key = cache_key
224 self.cache_args = cache_args
225 self.cache_active = False
226
227 def __repr__(self):
228 return "<CacheInvalidation('%s:%s')>" % (self.cache_id, self.cache_key)
229
230
231 def upgrade(migrate_engine):
232 # Upgrade operations go here. Don't create your own engine; bind migrate_engine
233 # to your metadata
234 Base.metadata.create_all(bind=migrate_engine, checkfirst=False)
235
236 def downgrade(migrate_engine):
237 # Operations to reverse the above upgrade go here.
238 Base.metadata.drop_all(bind=migrate_engine, checkfirst=False)
@@ -0,0 +1,118 b''
1 from sqlalchemy import *
2 from sqlalchemy.orm import relation
3
4 from migrate import *
5 from migrate.changeset import *
6 from rhodecode.model.meta import Base, BaseModel
7
8 def upgrade(migrate_engine):
9 """ Upgrade operations go here.
10 Don't create your own engine; bind migrate_engine to your metadata
11 """
12
13 #==========================================================================
14 # Upgrade of `users` table
15 #==========================================================================
16 tblname = 'users'
17 tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
18 autoload_with=migrate_engine)
19
20 #ADD is_ldap column
21 is_ldap = Column("is_ldap", Boolean(), nullable=False,
22 unique=None, default=False)
23 is_ldap.create(tbl)
24
25
26 #==========================================================================
27 # Upgrade of `user_logs` table
28 #==========================================================================
29
30 tblname = 'users'
31 tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
32 autoload_with=migrate_engine)
33
34 #ADD revision column
35 revision = Column('revision', TEXT(length=None, convert_unicode=False,
36 assert_unicode=None),
37 nullable=True, unique=None, default=None)
38 revision.create(tbl)
39
40
41
42 #==========================================================================
43 # Upgrade of `repositories` table
44 #==========================================================================
45 tblname = 'users'
46 tbl = Table(tblname, MetaData(bind=migrate_engine), autoload=True,
47 autoload_with=migrate_engine)
48
49 #ADD repo_type column
50 repo_type = Column("repo_type", String(length=None, convert_unicode=False,
51 assert_unicode=None),
52 nullable=False, unique=False, default=None)
53 repo_type.create(tbl)
54
55
56 #ADD statistics column
57 enable_statistics = Column("statistics", Boolean(), nullable=True,
58 unique=None, default=True)
59 enable_statistics.create(tbl)
60
61
62
63 #==========================================================================
64 # Add table `user_followings`
65 #==========================================================================
66 tblname = 'user_followings'
67 class UserFollowing(Base, BaseModel):
68 __tablename__ = 'user_followings'
69 __table_args__ = (UniqueConstraint('user_id', 'follows_repository_id'),
70 UniqueConstraint('user_id', 'follows_user_id')
71 , {'useexisting':True})
72
73 user_following_id = Column("user_following_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
74 user_id = Column("user_id", Integer(), ForeignKey(u'users.user_id'), nullable=False, unique=None, default=None)
75 follows_repo_id = Column("follows_repository_id", Integer(), ForeignKey(u'repositories.repo_id'), nullable=True, unique=None, default=None)
76 follows_user_id = Column("follows_user_id", Integer(), ForeignKey(u'users.user_id'), nullable=True, unique=None, default=None)
77
78 user = relation('User', primaryjoin='User.user_id==UserFollowing.user_id')
79
80 follows_user = relation('User', primaryjoin='User.user_id==UserFollowing.follows_user_id')
81 follows_repository = relation('Repository')
82
83 Base.metadata.tables[tblname].create(migrate_engine)
84
85 #==========================================================================
86 # Add table `cache_invalidation`
87 #==========================================================================
88 class CacheInvalidation(Base, BaseModel):
89 __tablename__ = 'cache_invalidation'
90 __table_args__ = (UniqueConstraint('cache_key'), {'useexisting':True})
91 cache_id = Column("cache_id", Integer(), nullable=False, unique=True, default=None, primary_key=True)
92 cache_key = Column("cache_key", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
93 cache_args = Column("cache_args", String(length=None, convert_unicode=False, assert_unicode=None), nullable=True, unique=None, default=None)
94 cache_active = Column("cache_active", Boolean(), nullable=True, unique=None, default=False)
95
96
97 def __init__(self, cache_key, cache_args=''):
98 self.cache_key = cache_key
99 self.cache_args = cache_args
100 self.cache_active = False
101
102 def __repr__(self):
103 return "<CacheInvalidation('%s:%s')>" % (self.cache_id, self.cache_key)
104
105 Base.metadata.tables[tblname].create(migrate_engine)
106
107 return
108
109
110
111
112
113
114 def downgrade(migrate_engine):
115 meta = MetaData()
116 meta.bind = migrate_engine
117
118
@@ -0,0 +1,26 b''
1 # -*- coding: utf-8 -*-
2 """
3 rhodecode.lib.dbmigrate.versions.__init__
4 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5
6 Package containing new versions of database models
7
8 :created_on: Dec 11, 2010
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
12 """
13 # This program is free software; you can redistribute it and/or
14 # modify it under the terms of the GNU General Public License
15 # as published by the Free Software Foundation; version 2
16 # of the License or (at your opinion) any later version of the license.
17 #
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 # GNU General Public License for more details.
22 #
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 # MA 02110-1301, USA.
@@ -1,628 +1,601 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 package.rhodecode.lib.utils
3 rhodecode.lib.utils
4 ~~~~~~~~~~~~~~
4 ~~~~~~~~~~~~~~~~~~~
5
5
6 Utilities library for RhodeCode
6 Utilities library for RhodeCode
7
7
8 :created_on: Apr 18, 2010
8 :created_on: Apr 18, 2010
9 :author: marcink
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
11 :license: GPLv3, see COPYING for more details.
12 """
12 """
13 # This program is free software; you can redistribute it and/or
13 # This program is free software; you can redistribute it and/or
14 # modify it under the terms of the GNU General Public License
14 # modify it under the terms of the GNU General Public License
15 # as published by the Free Software Foundation; version 2
15 # as published by the Free Software Foundation; version 2
16 # of the License or (at your opinion) any later version of the license.
16 # of the License or (at your opinion) any later version of the license.
17 #
17 #
18 # This program is distributed in the hope that it will be useful,
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 # GNU General Public License for more details.
21 # GNU General Public License for more details.
22 #
22 #
23 # You should have received a copy of the GNU General Public License
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 # MA 02110-1301, USA.
26 # MA 02110-1301, USA.
27
27
28 import os
28 import os
29 import logging
29 import logging
30 import datetime
30 import datetime
31 import traceback
31 import traceback
32
32
33 from UserDict import DictMixin
33 from UserDict import DictMixin
34
34
35 from mercurial import ui, config, hg
35 from mercurial import ui, config, hg
36 from mercurial.error import RepoError
36 from mercurial.error import RepoError
37
37
38 import paste
38 import paste
39 import beaker
39 import beaker
40 from paste.script.command import Command, BadCommand
40 from paste.script.command import Command, BadCommand
41
41
42 from vcs.backends.base import BaseChangeset
42 from vcs.backends.base import BaseChangeset
43 from vcs.utils.lazy import LazyProperty
43 from vcs.utils.lazy import LazyProperty
44
44
45 from rhodecode.model import meta
45 from rhodecode.model import meta
46 from rhodecode.model.caching_query import FromCache
46 from rhodecode.model.caching_query import FromCache
47 from rhodecode.model.db import Repository, User, RhodeCodeUi, UserLog
47 from rhodecode.model.db import Repository, User, RhodeCodeUi, UserLog
48 from rhodecode.model.repo import RepoModel
48 from rhodecode.model.repo import RepoModel
49 from rhodecode.model.user import UserModel
49 from rhodecode.model.user import UserModel
50
50
51 log = logging.getLogger(__name__)
51 log = logging.getLogger(__name__)
52
52
53
53
54 def get_repo_slug(request):
54 def get_repo_slug(request):
55 return request.environ['pylons.routes_dict'].get('repo_name')
55 return request.environ['pylons.routes_dict'].get('repo_name')
56
56
57 def action_logger(user, action, repo, ipaddr='', sa=None):
57 def action_logger(user, action, repo, ipaddr='', sa=None):
58 """
58 """
59 Action logger for various actions made by users
59 Action logger for various actions made by users
60
60
61 :param user: user that made this action, can be a unique username string or
61 :param user: user that made this action, can be a unique username string or
62 object containing user_id attribute
62 object containing user_id attribute
63 :param action: action to log, should be on of predefined unique actions for
63 :param action: action to log, should be on of predefined unique actions for
64 easy translations
64 easy translations
65 :param repo: string name of repository or object containing repo_id,
65 :param repo: string name of repository or object containing repo_id,
66 that action was made on
66 that action was made on
67 :param ipaddr: optional ip address from what the action was made
67 :param ipaddr: optional ip address from what the action was made
68 :param sa: optional sqlalchemy session
68 :param sa: optional sqlalchemy session
69
69
70 """
70 """
71
71
72 if not sa:
72 if not sa:
73 sa = meta.Session()
73 sa = meta.Session()
74
74
75 try:
75 try:
76 um = UserModel()
76 um = UserModel()
77 if hasattr(user, 'user_id'):
77 if hasattr(user, 'user_id'):
78 user_obj = user
78 user_obj = user
79 elif isinstance(user, basestring):
79 elif isinstance(user, basestring):
80 user_obj = um.get_by_username(user, cache=False)
80 user_obj = um.get_by_username(user, cache=False)
81 else:
81 else:
82 raise Exception('You have to provide user object or username')
82 raise Exception('You have to provide user object or username')
83
83
84
84
85 rm = RepoModel()
85 rm = RepoModel()
86 if hasattr(repo, 'repo_id'):
86 if hasattr(repo, 'repo_id'):
87 repo_obj = rm.get(repo.repo_id, cache=False)
87 repo_obj = rm.get(repo.repo_id, cache=False)
88 repo_name = repo_obj.repo_name
88 repo_name = repo_obj.repo_name
89 elif isinstance(repo, basestring):
89 elif isinstance(repo, basestring):
90 repo_name = repo.lstrip('/')
90 repo_name = repo.lstrip('/')
91 repo_obj = rm.get_by_repo_name(repo_name, cache=False)
91 repo_obj = rm.get_by_repo_name(repo_name, cache=False)
92 else:
92 else:
93 raise Exception('You have to provide repository to action logger')
93 raise Exception('You have to provide repository to action logger')
94
94
95
95
96 user_log = UserLog()
96 user_log = UserLog()
97 user_log.user_id = user_obj.user_id
97 user_log.user_id = user_obj.user_id
98 user_log.action = action
98 user_log.action = action
99
99
100 user_log.repository_id = repo_obj.repo_id
100 user_log.repository_id = repo_obj.repo_id
101 user_log.repository_name = repo_name
101 user_log.repository_name = repo_name
102
102
103 user_log.action_date = datetime.datetime.now()
103 user_log.action_date = datetime.datetime.now()
104 user_log.user_ip = ipaddr
104 user_log.user_ip = ipaddr
105 sa.add(user_log)
105 sa.add(user_log)
106 sa.commit()
106 sa.commit()
107
107
108 log.info('Adding user %s, action %s on %s', user_obj, action, repo)
108 log.info('Adding user %s, action %s on %s', user_obj, action, repo)
109 except:
109 except:
110 log.error(traceback.format_exc())
110 log.error(traceback.format_exc())
111 sa.rollback()
111 sa.rollback()
112
112
113 def get_repos(path, recursive=False, initial=False):
113 def get_repos(path, recursive=False, initial=False):
114 """
114 """
115 Scans given path for repos and return (name,(type,path)) tuple
115 Scans given path for repos and return (name,(type,path)) tuple
116 :param prefix:
116 :param prefix:
117 :param path:
117 :param path:
118 :param recursive:
118 :param recursive:
119 :param initial:
119 :param initial:
120 """
120 """
121 from vcs.utils.helpers import get_scm
121 from vcs.utils.helpers import get_scm
122 from vcs.exceptions import VCSError
122 from vcs.exceptions import VCSError
123
123
124 try:
124 try:
125 scm = get_scm(path)
125 scm = get_scm(path)
126 except:
126 except:
127 pass
127 pass
128 else:
128 else:
129 raise Exception('The given path %s should not be a repository got %s',
129 raise Exception('The given path %s should not be a repository got %s',
130 path, scm)
130 path, scm)
131
131
132 for dirpath in os.listdir(path):
132 for dirpath in os.listdir(path):
133 try:
133 try:
134 yield dirpath, get_scm(os.path.join(path, dirpath))
134 yield dirpath, get_scm(os.path.join(path, dirpath))
135 except VCSError:
135 except VCSError:
136 pass
136 pass
137
137
138 def check_repo_fast(repo_name, base_path):
138 def check_repo_fast(repo_name, base_path):
139 """
139 """
140 Check given path for existance of directory
140 Check given path for existance of directory
141 :param repo_name:
141 :param repo_name:
142 :param base_path:
142 :param base_path:
143
143
144 :return False: if this directory is present
144 :return False: if this directory is present
145 """
145 """
146 if os.path.isdir(os.path.join(base_path, repo_name)):return False
146 if os.path.isdir(os.path.join(base_path, repo_name)):return False
147 return True
147 return True
148
148
149 def check_repo(repo_name, base_path, verify=True):
149 def check_repo(repo_name, base_path, verify=True):
150
150
151 repo_path = os.path.join(base_path, repo_name)
151 repo_path = os.path.join(base_path, repo_name)
152
152
153 try:
153 try:
154 if not check_repo_fast(repo_name, base_path):
154 if not check_repo_fast(repo_name, base_path):
155 return False
155 return False
156 r = hg.repository(ui.ui(), repo_path)
156 r = hg.repository(ui.ui(), repo_path)
157 if verify:
157 if verify:
158 hg.verify(r)
158 hg.verify(r)
159 #here we hnow that repo exists it was verified
159 #here we hnow that repo exists it was verified
160 log.info('%s repo is already created', repo_name)
160 log.info('%s repo is already created', repo_name)
161 return False
161 return False
162 except RepoError:
162 except RepoError:
163 #it means that there is no valid repo there...
163 #it means that there is no valid repo there...
164 log.info('%s repo is free for creation', repo_name)
164 log.info('%s repo is free for creation', repo_name)
165 return True
165 return True
166
166
167 def ask_ok(prompt, retries=4, complaint='Yes or no, please!'):
167 def ask_ok(prompt, retries=4, complaint='Yes or no, please!'):
168 while True:
168 while True:
169 ok = raw_input(prompt)
169 ok = raw_input(prompt)
170 if ok in ('y', 'ye', 'yes'): return True
170 if ok in ('y', 'ye', 'yes'): return True
171 if ok in ('n', 'no', 'nop', 'nope'): return False
171 if ok in ('n', 'no', 'nop', 'nope'): return False
172 retries = retries - 1
172 retries = retries - 1
173 if retries < 0: raise IOError
173 if retries < 0: raise IOError
174 print complaint
174 print complaint
175
175
176 #propagated from mercurial documentation
176 #propagated from mercurial documentation
177 ui_sections = ['alias', 'auth',
177 ui_sections = ['alias', 'auth',
178 'decode/encode', 'defaults',
178 'decode/encode', 'defaults',
179 'diff', 'email',
179 'diff', 'email',
180 'extensions', 'format',
180 'extensions', 'format',
181 'merge-patterns', 'merge-tools',
181 'merge-patterns', 'merge-tools',
182 'hooks', 'http_proxy',
182 'hooks', 'http_proxy',
183 'smtp', 'patch',
183 'smtp', 'patch',
184 'paths', 'profiling',
184 'paths', 'profiling',
185 'server', 'trusted',
185 'server', 'trusted',
186 'ui', 'web', ]
186 'ui', 'web', ]
187
187
188 def make_ui(read_from='file', path=None, checkpaths=True):
188 def make_ui(read_from='file', path=None, checkpaths=True):
189 """
189 """
190 A function that will read python rc files or database
190 A function that will read python rc files or database
191 and make an mercurial ui object from read options
191 and make an mercurial ui object from read options
192
192
193 :param path: path to mercurial config file
193 :param path: path to mercurial config file
194 :param checkpaths: check the path
194 :param checkpaths: check the path
195 :param read_from: read from 'file' or 'db'
195 :param read_from: read from 'file' or 'db'
196 """
196 """
197
197
198 baseui = ui.ui()
198 baseui = ui.ui()
199
199
200 #clean the baseui object
200 #clean the baseui object
201 baseui._ocfg = config.config()
201 baseui._ocfg = config.config()
202 baseui._ucfg = config.config()
202 baseui._ucfg = config.config()
203 baseui._tcfg = config.config()
203 baseui._tcfg = config.config()
204
204
205 if read_from == 'file':
205 if read_from == 'file':
206 if not os.path.isfile(path):
206 if not os.path.isfile(path):
207 log.warning('Unable to read config file %s' % path)
207 log.warning('Unable to read config file %s' % path)
208 return False
208 return False
209 log.debug('reading hgrc from %s', path)
209 log.debug('reading hgrc from %s', path)
210 cfg = config.config()
210 cfg = config.config()
211 cfg.read(path)
211 cfg.read(path)
212 for section in ui_sections:
212 for section in ui_sections:
213 for k, v in cfg.items(section):
213 for k, v in cfg.items(section):
214 log.debug('settings ui from file[%s]%s:%s', section, k, v)
214 log.debug('settings ui from file[%s]%s:%s', section, k, v)
215 baseui.setconfig(section, k, v)
215 baseui.setconfig(section, k, v)
216
216
217
217
218 elif read_from == 'db':
218 elif read_from == 'db':
219 sa = meta.Session()
219 sa = meta.Session()
220 ret = sa.query(RhodeCodeUi)\
220 ret = sa.query(RhodeCodeUi)\
221 .options(FromCache("sql_cache_short",
221 .options(FromCache("sql_cache_short",
222 "get_hg_ui_settings")).all()
222 "get_hg_ui_settings")).all()
223
223
224 hg_ui = ret
224 hg_ui = ret
225 for ui_ in hg_ui:
225 for ui_ in hg_ui:
226 if ui_.ui_active:
226 if ui_.ui_active:
227 log.debug('settings ui from db[%s]%s:%s', ui_.ui_section,
227 log.debug('settings ui from db[%s]%s:%s', ui_.ui_section,
228 ui_.ui_key, ui_.ui_value)
228 ui_.ui_key, ui_.ui_value)
229 baseui.setconfig(ui_.ui_section, ui_.ui_key, ui_.ui_value)
229 baseui.setconfig(ui_.ui_section, ui_.ui_key, ui_.ui_value)
230
230
231 meta.Session.remove()
231 meta.Session.remove()
232 return baseui
232 return baseui
233
233
234
234
235 def set_rhodecode_config(config):
235 def set_rhodecode_config(config):
236 """
236 """
237 Updates pylons config with new settings from database
237 Updates pylons config with new settings from database
238 :param config:
238 :param config:
239 """
239 """
240 from rhodecode.model.settings import SettingsModel
240 from rhodecode.model.settings import SettingsModel
241 hgsettings = SettingsModel().get_app_settings()
241 hgsettings = SettingsModel().get_app_settings()
242
242
243 for k, v in hgsettings.items():
243 for k, v in hgsettings.items():
244 config[k] = v
244 config[k] = v
245
245
246 def invalidate_cache(cache_key, *args):
246 def invalidate_cache(cache_key, *args):
247 """
247 """
248 Puts cache invalidation task into db for
248 Puts cache invalidation task into db for
249 further global cache invalidation
249 further global cache invalidation
250 """
250 """
251 from rhodecode.model.scm import ScmModel
251 from rhodecode.model.scm import ScmModel
252
252
253 if cache_key.startswith('get_repo_cached_'):
253 if cache_key.startswith('get_repo_cached_'):
254 name = cache_key.split('get_repo_cached_')[-1]
254 name = cache_key.split('get_repo_cached_')[-1]
255 ScmModel().mark_for_invalidation(name)
255 ScmModel().mark_for_invalidation(name)
256
256
257 class EmptyChangeset(BaseChangeset):
257 class EmptyChangeset(BaseChangeset):
258 """
258 """
259 An dummy empty changeset. It's possible to pass hash when creating
259 An dummy empty changeset. It's possible to pass hash when creating
260 an EmptyChangeset
260 an EmptyChangeset
261 """
261 """
262
262
263 def __init__(self, cs='0' * 40):
263 def __init__(self, cs='0' * 40):
264 self._empty_cs = cs
264 self._empty_cs = cs
265 self.revision = -1
265 self.revision = -1
266 self.message = ''
266 self.message = ''
267 self.author = ''
267 self.author = ''
268 self.date = ''
268 self.date = ''
269
269
270 @LazyProperty
270 @LazyProperty
271 def raw_id(self):
271 def raw_id(self):
272 """
272 """
273 Returns raw string identifying this changeset, useful for web
273 Returns raw string identifying this changeset, useful for web
274 representation.
274 representation.
275 """
275 """
276 return self._empty_cs
276 return self._empty_cs
277
277
278 @LazyProperty
278 @LazyProperty
279 def short_id(self):
279 def short_id(self):
280 return self.raw_id[:12]
280 return self.raw_id[:12]
281
281
282 def get_file_changeset(self, path):
282 def get_file_changeset(self, path):
283 return self
283 return self
284
284
285 def get_file_content(self, path):
285 def get_file_content(self, path):
286 return u''
286 return u''
287
287
288 def get_file_size(self, path):
288 def get_file_size(self, path):
289 return 0
289 return 0
290
290
291 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
291 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
292 """
292 """
293 maps all found repositories into db
293 maps all found repositories into db
294 """
294 """
295
295
296 sa = meta.Session()
296 sa = meta.Session()
297 rm = RepoModel()
297 rm = RepoModel()
298 user = sa.query(User).filter(User.admin == True).first()
298 user = sa.query(User).filter(User.admin == True).first()
299
299
300 for name, repo in initial_repo_list.items():
300 for name, repo in initial_repo_list.items():
301 if not rm.get_by_repo_name(name, cache=False):
301 if not rm.get_by_repo_name(name, cache=False):
302 log.info('repository %s not found creating default', name)
302 log.info('repository %s not found creating default', name)
303
303
304 form_data = {
304 form_data = {
305 'repo_name':name,
305 'repo_name':name,
306 'repo_type':repo.alias,
306 'repo_type':repo.alias,
307 'description':repo.description \
307 'description':repo.description \
308 if repo.description != 'unknown' else \
308 if repo.description != 'unknown' else \
309 '%s repository' % name,
309 '%s repository' % name,
310 'private':False
310 'private':False
311 }
311 }
312 rm.create(form_data, user, just_db=True)
312 rm.create(form_data, user, just_db=True)
313
313
314 if remove_obsolete:
314 if remove_obsolete:
315 #remove from database those repositories that are not in the filesystem
315 #remove from database those repositories that are not in the filesystem
316 for repo in sa.query(Repository).all():
316 for repo in sa.query(Repository).all():
317 if repo.repo_name not in initial_repo_list.keys():
317 if repo.repo_name not in initial_repo_list.keys():
318 sa.delete(repo)
318 sa.delete(repo)
319 sa.commit()
319 sa.commit()
320
320
321 class OrderedDict(dict, DictMixin):
321 class OrderedDict(dict, DictMixin):
322
322
323 def __init__(self, *args, **kwds):
323 def __init__(self, *args, **kwds):
324 if len(args) > 1:
324 if len(args) > 1:
325 raise TypeError('expected at most 1 arguments, got %d' % len(args))
325 raise TypeError('expected at most 1 arguments, got %d' % len(args))
326 try:
326 try:
327 self.__end
327 self.__end
328 except AttributeError:
328 except AttributeError:
329 self.clear()
329 self.clear()
330 self.update(*args, **kwds)
330 self.update(*args, **kwds)
331
331
332 def clear(self):
332 def clear(self):
333 self.__end = end = []
333 self.__end = end = []
334 end += [None, end, end] # sentinel node for doubly linked list
334 end += [None, end, end] # sentinel node for doubly linked list
335 self.__map = {} # key --> [key, prev, next]
335 self.__map = {} # key --> [key, prev, next]
336 dict.clear(self)
336 dict.clear(self)
337
337
338 def __setitem__(self, key, value):
338 def __setitem__(self, key, value):
339 if key not in self:
339 if key not in self:
340 end = self.__end
340 end = self.__end
341 curr = end[1]
341 curr = end[1]
342 curr[2] = end[1] = self.__map[key] = [key, curr, end]
342 curr[2] = end[1] = self.__map[key] = [key, curr, end]
343 dict.__setitem__(self, key, value)
343 dict.__setitem__(self, key, value)
344
344
345 def __delitem__(self, key):
345 def __delitem__(self, key):
346 dict.__delitem__(self, key)
346 dict.__delitem__(self, key)
347 key, prev, next = self.__map.pop(key)
347 key, prev, next = self.__map.pop(key)
348 prev[2] = next
348 prev[2] = next
349 next[1] = prev
349 next[1] = prev
350
350
351 def __iter__(self):
351 def __iter__(self):
352 end = self.__end
352 end = self.__end
353 curr = end[2]
353 curr = end[2]
354 while curr is not end:
354 while curr is not end:
355 yield curr[0]
355 yield curr[0]
356 curr = curr[2]
356 curr = curr[2]
357
357
358 def __reversed__(self):
358 def __reversed__(self):
359 end = self.__end
359 end = self.__end
360 curr = end[1]
360 curr = end[1]
361 while curr is not end:
361 while curr is not end:
362 yield curr[0]
362 yield curr[0]
363 curr = curr[1]
363 curr = curr[1]
364
364
365 def popitem(self, last=True):
365 def popitem(self, last=True):
366 if not self:
366 if not self:
367 raise KeyError('dictionary is empty')
367 raise KeyError('dictionary is empty')
368 if last:
368 if last:
369 key = reversed(self).next()
369 key = reversed(self).next()
370 else:
370 else:
371 key = iter(self).next()
371 key = iter(self).next()
372 value = self.pop(key)
372 value = self.pop(key)
373 return key, value
373 return key, value
374
374
375 def __reduce__(self):
375 def __reduce__(self):
376 items = [[k, self[k]] for k in self]
376 items = [[k, self[k]] for k in self]
377 tmp = self.__map, self.__end
377 tmp = self.__map, self.__end
378 del self.__map, self.__end
378 del self.__map, self.__end
379 inst_dict = vars(self).copy()
379 inst_dict = vars(self).copy()
380 self.__map, self.__end = tmp
380 self.__map, self.__end = tmp
381 if inst_dict:
381 if inst_dict:
382 return (self.__class__, (items,), inst_dict)
382 return (self.__class__, (items,), inst_dict)
383 return self.__class__, (items,)
383 return self.__class__, (items,)
384
384
385 def keys(self):
385 def keys(self):
386 return list(self)
386 return list(self)
387
387
388 setdefault = DictMixin.setdefault
388 setdefault = DictMixin.setdefault
389 update = DictMixin.update
389 update = DictMixin.update
390 pop = DictMixin.pop
390 pop = DictMixin.pop
391 values = DictMixin.values
391 values = DictMixin.values
392 items = DictMixin.items
392 items = DictMixin.items
393 iterkeys = DictMixin.iterkeys
393 iterkeys = DictMixin.iterkeys
394 itervalues = DictMixin.itervalues
394 itervalues = DictMixin.itervalues
395 iteritems = DictMixin.iteritems
395 iteritems = DictMixin.iteritems
396
396
397 def __repr__(self):
397 def __repr__(self):
398 if not self:
398 if not self:
399 return '%s()' % (self.__class__.__name__,)
399 return '%s()' % (self.__class__.__name__,)
400 return '%s(%r)' % (self.__class__.__name__, self.items())
400 return '%s(%r)' % (self.__class__.__name__, self.items())
401
401
402 def copy(self):
402 def copy(self):
403 return self.__class__(self)
403 return self.__class__(self)
404
404
405 @classmethod
405 @classmethod
406 def fromkeys(cls, iterable, value=None):
406 def fromkeys(cls, iterable, value=None):
407 d = cls()
407 d = cls()
408 for key in iterable:
408 for key in iterable:
409 d[key] = value
409 d[key] = value
410 return d
410 return d
411
411
412 def __eq__(self, other):
412 def __eq__(self, other):
413 if isinstance(other, OrderedDict):
413 if isinstance(other, OrderedDict):
414 return len(self) == len(other) and self.items() == other.items()
414 return len(self) == len(other) and self.items() == other.items()
415 return dict.__eq__(self, other)
415 return dict.__eq__(self, other)
416
416
417 def __ne__(self, other):
417 def __ne__(self, other):
418 return not self == other
418 return not self == other
419
419
420
420
421 #set cache regions for beaker so celery can utilise it
421 #set cache regions for beaker so celery can utilise it
422 def add_cache(settings):
422 def add_cache(settings):
423 cache_settings = {'regions':None}
423 cache_settings = {'regions':None}
424 for key in settings.keys():
424 for key in settings.keys():
425 for prefix in ['beaker.cache.', 'cache.']:
425 for prefix in ['beaker.cache.', 'cache.']:
426 if key.startswith(prefix):
426 if key.startswith(prefix):
427 name = key.split(prefix)[1].strip()
427 name = key.split(prefix)[1].strip()
428 cache_settings[name] = settings[key].strip()
428 cache_settings[name] = settings[key].strip()
429 if cache_settings['regions']:
429 if cache_settings['regions']:
430 for region in cache_settings['regions'].split(','):
430 for region in cache_settings['regions'].split(','):
431 region = region.strip()
431 region = region.strip()
432 region_settings = {}
432 region_settings = {}
433 for key, value in cache_settings.items():
433 for key, value in cache_settings.items():
434 if key.startswith(region):
434 if key.startswith(region):
435 region_settings[key.split('.')[1]] = value
435 region_settings[key.split('.')[1]] = value
436 region_settings['expire'] = int(region_settings.get('expire',
436 region_settings['expire'] = int(region_settings.get('expire',
437 60))
437 60))
438 region_settings.setdefault('lock_dir',
438 region_settings.setdefault('lock_dir',
439 cache_settings.get('lock_dir'))
439 cache_settings.get('lock_dir'))
440 if 'type' not in region_settings:
440 if 'type' not in region_settings:
441 region_settings['type'] = cache_settings.get('type',
441 region_settings['type'] = cache_settings.get('type',
442 'memory')
442 'memory')
443 beaker.cache.cache_regions[region] = region_settings
443 beaker.cache.cache_regions[region] = region_settings
444
444
445 def get_current_revision():
445 def get_current_revision():
446 """
446 """
447 Returns tuple of (number, id) from repository containing this package
447 Returns tuple of (number, id) from repository containing this package
448 or None if repository could not be found.
448 or None if repository could not be found.
449 """
449 """
450 try:
450 try:
451 from vcs import get_repo
451 from vcs import get_repo
452 from vcs.utils.helpers import get_scm
452 from vcs.utils.helpers import get_scm
453 from vcs.exceptions import RepositoryError, VCSError
453 from vcs.exceptions import RepositoryError, VCSError
454 repopath = os.path.join(os.path.dirname(__file__), '..', '..')
454 repopath = os.path.join(os.path.dirname(__file__), '..', '..')
455 scm = get_scm(repopath)[0]
455 scm = get_scm(repopath)[0]
456 repo = get_repo(path=repopath, alias=scm)
456 repo = get_repo(path=repopath, alias=scm)
457 tip = repo.get_changeset()
457 tip = repo.get_changeset()
458 return (tip.revision, tip.short_id)
458 return (tip.revision, tip.short_id)
459 except (ImportError, RepositoryError, VCSError), err:
459 except (ImportError, RepositoryError, VCSError), err:
460 logging.debug("Cannot retrieve rhodecode's revision. Original error "
460 logging.debug("Cannot retrieve rhodecode's revision. Original error "
461 "was: %s" % err)
461 "was: %s" % err)
462 return None
462 return None
463
463
464 #===============================================================================
464 #===============================================================================
465 # TEST FUNCTIONS AND CREATORS
465 # TEST FUNCTIONS AND CREATORS
466 #===============================================================================
466 #===============================================================================
467 def create_test_index(repo_location, full_index):
467 def create_test_index(repo_location, full_index):
468 """Makes default test index
468 """Makes default test index
469 :param repo_location:
469 :param repo_location:
470 :param full_index:
470 :param full_index:
471 """
471 """
472 from rhodecode.lib.indexers.daemon import WhooshIndexingDaemon
472 from rhodecode.lib.indexers.daemon import WhooshIndexingDaemon
473 from rhodecode.lib.pidlock import DaemonLock, LockHeld
473 from rhodecode.lib.pidlock import DaemonLock, LockHeld
474 import shutil
474 import shutil
475
475
476 index_location = os.path.join(repo_location, 'index')
476 index_location = os.path.join(repo_location, 'index')
477 if os.path.exists(index_location):
477 if os.path.exists(index_location):
478 shutil.rmtree(index_location)
478 shutil.rmtree(index_location)
479
479
480 try:
480 try:
481 l = DaemonLock()
481 l = DaemonLock()
482 WhooshIndexingDaemon(index_location=index_location,
482 WhooshIndexingDaemon(index_location=index_location,
483 repo_location=repo_location)\
483 repo_location=repo_location)\
484 .run(full_index=full_index)
484 .run(full_index=full_index)
485 l.release()
485 l.release()
486 except LockHeld:
486 except LockHeld:
487 pass
487 pass
488
488
489 def create_test_env(repos_test_path, config):
489 def create_test_env(repos_test_path, config):
490 """Makes a fresh database and
490 """Makes a fresh database and
491 install test repository into tmp dir
491 install test repository into tmp dir
492 """
492 """
493 from rhodecode.lib.db_manage import DbManage
493 from rhodecode.lib.db_manage import DbManage
494 from rhodecode.tests import HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, \
494 from rhodecode.tests import HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, \
495 HG_FORK, GIT_FORK, TESTS_TMP_PATH
495 HG_FORK, GIT_FORK, TESTS_TMP_PATH
496 import tarfile
496 import tarfile
497 import shutil
497 import shutil
498 from os.path import dirname as dn, join as jn, abspath
498 from os.path import dirname as dn, join as jn, abspath
499
499
500 log = logging.getLogger('TestEnvCreator')
500 log = logging.getLogger('TestEnvCreator')
501 # create logger
501 # create logger
502 log.setLevel(logging.DEBUG)
502 log.setLevel(logging.DEBUG)
503 log.propagate = True
503 log.propagate = True
504 # create console handler and set level to debug
504 # create console handler and set level to debug
505 ch = logging.StreamHandler()
505 ch = logging.StreamHandler()
506 ch.setLevel(logging.DEBUG)
506 ch.setLevel(logging.DEBUG)
507
507
508 # create formatter
508 # create formatter
509 formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
509 formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
510
510
511 # add formatter to ch
511 # add formatter to ch
512 ch.setFormatter(formatter)
512 ch.setFormatter(formatter)
513
513
514 # add ch to logger
514 # add ch to logger
515 log.addHandler(ch)
515 log.addHandler(ch)
516
516
517 #PART ONE create db
517 #PART ONE create db
518 dbconf = config['sqlalchemy.db1.url']
518 dbconf = config['sqlalchemy.db1.url']
519 log.debug('making test db %s', dbconf)
519 log.debug('making test db %s', dbconf)
520
520
521 dbmanage = DbManage(log_sql=True, dbconf=dbconf, root=config['here'],
521 dbmanage = DbManage(log_sql=True, dbconf=dbconf, root=config['here'],
522 tests=True)
522 tests=True)
523 dbmanage.create_tables(override=True)
523 dbmanage.create_tables(override=True)
524 dbmanage.config_prompt(repos_test_path)
524 dbmanage.config_prompt(repos_test_path)
525 dbmanage.create_default_user()
525 dbmanage.create_default_user()
526 dbmanage.admin_prompt()
526 dbmanage.admin_prompt()
527 dbmanage.create_permissions()
527 dbmanage.create_permissions()
528 dbmanage.populate_default_permissions()
528 dbmanage.populate_default_permissions()
529
529
530 #PART TWO make test repo
530 #PART TWO make test repo
531 log.debug('making test vcs repositories')
531 log.debug('making test vcs repositories')
532
532
533 #remove old one from previos tests
533 #remove old one from previos tests
534 for r in [HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, HG_FORK, GIT_FORK]:
534 for r in [HG_REPO, GIT_REPO, NEW_HG_REPO, NEW_GIT_REPO, HG_FORK, GIT_FORK]:
535
535
536 if os.path.isdir(jn(TESTS_TMP_PATH, r)):
536 if os.path.isdir(jn(TESTS_TMP_PATH, r)):
537 log.debug('removing %s', r)
537 log.debug('removing %s', r)
538 shutil.rmtree(jn(TESTS_TMP_PATH, r))
538 shutil.rmtree(jn(TESTS_TMP_PATH, r))
539
539
540 #CREATE DEFAULT HG REPOSITORY
540 #CREATE DEFAULT HG REPOSITORY
541 cur_dir = dn(dn(abspath(__file__)))
541 cur_dir = dn(dn(abspath(__file__)))
542 tar = tarfile.open(jn(cur_dir, 'tests', "vcs_test_hg.tar.gz"))
542 tar = tarfile.open(jn(cur_dir, 'tests', "vcs_test_hg.tar.gz"))
543 tar.extractall(jn(TESTS_TMP_PATH, HG_REPO))
543 tar.extractall(jn(TESTS_TMP_PATH, HG_REPO))
544 tar.close()
544 tar.close()
545
545
546
546
547 #==============================================================================
547 #==============================================================================
548 # PASTER COMMANDS
548 # PASTER COMMANDS
549 #==============================================================================
549 #==============================================================================
550
550
551 class BasePasterCommand(Command):
551 class BasePasterCommand(Command):
552 """
552 """
553 Abstract Base Class for paster commands.
553 Abstract Base Class for paster commands.
554
554
555 The celery commands are somewhat aggressive about loading
555 The celery commands are somewhat aggressive about loading
556 celery.conf, and since our module sets the `CELERY_LOADER`
556 celery.conf, and since our module sets the `CELERY_LOADER`
557 environment variable to our loader, we have to bootstrap a bit and
557 environment variable to our loader, we have to bootstrap a bit and
558 make sure we've had a chance to load the pylons config off of the
558 make sure we've had a chance to load the pylons config off of the
559 command line, otherwise everything fails.
559 command line, otherwise everything fails.
560 """
560 """
561 min_args = 1
561 min_args = 1
562 min_args_error = "Please provide a paster config file as an argument."
562 min_args_error = "Please provide a paster config file as an argument."
563 takes_config_file = 1
563 takes_config_file = 1
564 requires_config_file = True
564 requires_config_file = True
565
565
566 def run(self, args):
566 def run(self, args):
567 """
567 """
568 Overrides Command.run
568 Overrides Command.run
569
569
570 Checks for a config file argument and loads it.
570 Checks for a config file argument and loads it.
571 """
571 """
572 if len(args) < self.min_args:
572 if len(args) < self.min_args:
573 raise BadCommand(
573 raise BadCommand(
574 self.min_args_error % {'min_args': self.min_args,
574 self.min_args_error % {'min_args': self.min_args,
575 'actual_args': len(args)})
575 'actual_args': len(args)})
576
576
577 # Decrement because we're going to lob off the first argument.
577 # Decrement because we're going to lob off the first argument.
578 # @@ This is hacky
578 # @@ This is hacky
579 self.min_args -= 1
579 self.min_args -= 1
580 self.bootstrap_config(args[0])
580 self.bootstrap_config(args[0])
581 self.update_parser()
581 self.update_parser()
582 return super(BasePasterCommand, self).run(args[1:])
582 return super(BasePasterCommand, self).run(args[1:])
583
583
584 def update_parser(self):
584 def update_parser(self):
585 """
585 """
586 Abstract method. Allows for the class's parser to be updated
586 Abstract method. Allows for the class's parser to be updated
587 before the superclass's `run` method is called. Necessary to
587 before the superclass's `run` method is called. Necessary to
588 allow options/arguments to be passed through to the underlying
588 allow options/arguments to be passed through to the underlying
589 celery command.
589 celery command.
590 """
590 """
591 raise NotImplementedError("Abstract Method.")
591 raise NotImplementedError("Abstract Method.")
592
592
593 def bootstrap_config(self, conf):
593 def bootstrap_config(self, conf):
594 """
594 """
595 Loads the pylons configuration.
595 Loads the pylons configuration.
596 """
596 """
597 from pylons import config as pylonsconfig
597 from pylons import config as pylonsconfig
598
598
599 path_to_ini_file = os.path.realpath(conf)
599 path_to_ini_file = os.path.realpath(conf)
600 conf = paste.deploy.appconfig('config:' + path_to_ini_file)
600 conf = paste.deploy.appconfig('config:' + path_to_ini_file)
601 pylonsconfig.init_app(conf.global_conf, conf.local_conf)
601 pylonsconfig.init_app(conf.global_conf, conf.local_conf)
602
603
604
605 class UpgradeDb(BasePasterCommand):
606 """Command used for paster to upgrade our database to newer version
607 """
608
609 max_args = 1
610 min_args = 1
611
612 usage = "CONFIG_FILE"
613 summary = "Upgrades current db to newer version given configuration file"
614 group_name = "RhodeCode"
615
616 parser = Command.standard_parser(verbose=True)
617
618 def command(self):
619 from pylons import config
620 raise NotImplementedError('Not implemented yet')
621
622
623 def update_parser(self):
624 self.parser.add_option('--sql',
625 action='store_true',
626 dest='just_sql',
627 help="Prints upgrade sql for further investigation",
628 default=False)
@@ -1,103 +1,103 b''
1 import sys
1 import sys
2 py_version = sys.version_info
2 py_version = sys.version_info
3
3
4 from rhodecode import get_version
4 from rhodecode import get_version
5
5
6 requirements = [
6 requirements = [
7 "Pylons==1.0.0",
7 "Pylons==1.0.0",
8 "SQLAlchemy==0.6.5",
8 "SQLAlchemy==0.6.5",
9 "Mako==0.3.6",
9 "Mako==0.3.6",
10 "vcs==0.1.10",
10 "vcs==0.1.10",
11 "pygments==1.3.1",
11 "pygments==1.3.1",
12 "mercurial==1.7.2",
12 "mercurial==1.7.2",
13 "whoosh==1.3.4",
13 "whoosh==1.3.4",
14 "celery==2.1.4",
14 "celery==2.1.4",
15 "py-bcrypt",
15 "py-bcrypt",
16 "babel",
16 "babel",
17 ]
17 ]
18
18
19 classifiers = ['Development Status :: 4 - Beta',
19 classifiers = ['Development Status :: 4 - Beta',
20 'Environment :: Web Environment',
20 'Environment :: Web Environment',
21 'Framework :: Pylons',
21 'Framework :: Pylons',
22 'Intended Audience :: Developers',
22 'Intended Audience :: Developers',
23 'License :: OSI Approved :: BSD License',
23 'License :: OSI Approved :: BSD License',
24 'Operating System :: OS Independent',
24 'Operating System :: OS Independent',
25 'Programming Language :: Python', ]
25 'Programming Language :: Python', ]
26
26
27 if sys.version_info < (2, 6):
27 if sys.version_info < (2, 6):
28 requirements.append("simplejson")
28 requirements.append("simplejson")
29 requirements.append("pysqlite")
29 requirements.append("pysqlite")
30
30
31 #additional files from project that goes somewhere in the filesystem
31 #additional files from project that goes somewhere in the filesystem
32 #relative to sys.prefix
32 #relative to sys.prefix
33 data_files = []
33 data_files = []
34
34
35 #additional files that goes into package itself
35 #additional files that goes into package itself
36 package_data = {'rhodecode': ['i18n/*/LC_MESSAGES/*.mo', ], }
36 package_data = {'rhodecode': ['i18n/*/LC_MESSAGES/*.mo', ], }
37
37
38 description = ('Mercurial repository browser/management with '
38 description = ('Mercurial repository browser/management with '
39 'build in push/pull server and full text search')
39 'build in push/pull server and full text search')
40 #long description
40 #long description
41 try:
41 try:
42 readme_file = 'README.rst'
42 readme_file = 'README.rst'
43 changelog_file = 'docs/changelog.rst'
43 changelog_file = 'docs/changelog.rst'
44 long_description = open(readme_file).read() + '/n/n' + \
44 long_description = open(readme_file).read() + '/n/n' + \
45 open(changelog_file).read()
45 open(changelog_file).read()
46
46
47 except IOError, err:
47 except IOError, err:
48 sys.stderr.write("[WARNING] Cannot find file specified as "
48 sys.stderr.write("[WARNING] Cannot find file specified as "
49 "long_description (%s)\n or changelog (%s) skipping that file" \
49 "long_description (%s)\n or changelog (%s) skipping that file" \
50 % (readme_file, changelog_file))
50 % (readme_file, changelog_file))
51 long_description = description
51 long_description = description
52
52
53
53
54 try:
54 try:
55 from setuptools import setup, find_packages
55 from setuptools import setup, find_packages
56 except ImportError:
56 except ImportError:
57 from ez_setup import use_setuptools
57 from ez_setup import use_setuptools
58 use_setuptools()
58 use_setuptools()
59 from setuptools import setup, find_packages
59 from setuptools import setup, find_packages
60 #packages
60 #packages
61 packages = find_packages(exclude=['ez_setup'])
61 packages = find_packages(exclude=['ez_setup'])
62
62
63 setup(
63 setup(
64 name='RhodeCode',
64 name='RhodeCode',
65 version=get_version(),
65 version=get_version(),
66 description=description,
66 description=description,
67 long_description=long_description,
67 long_description=long_description,
68 keywords='rhodiumcode mercurial web hgwebdir gitweb git replacement serving hgweb rhodecode',
68 keywords='rhodiumcode mercurial web hgwebdir gitweb git replacement serving hgweb rhodecode',
69 license='BSD',
69 license='BSD',
70 author='Marcin Kuzminski',
70 author='Marcin Kuzminski',
71 author_email='marcin@python-works.com',
71 author_email='marcin@python-works.com',
72 url='http://hg.python-works.com',
72 url='http://hg.python-works.com',
73 install_requires=requirements,
73 install_requires=requirements,
74 classifiers=classifiers,
74 classifiers=classifiers,
75 setup_requires=["PasteScript>=1.6.3"],
75 setup_requires=["PasteScript>=1.6.3"],
76 data_files=data_files,
76 data_files=data_files,
77 packages=packages,
77 packages=packages,
78 include_package_data=True,
78 include_package_data=True,
79 test_suite='nose.collector',
79 test_suite='nose.collector',
80 package_data=package_data,
80 package_data=package_data,
81 message_extractors={'rhodecode': [
81 message_extractors={'rhodecode': [
82 ('**.py', 'python', None),
82 ('**.py', 'python', None),
83 ('templates/**.mako', 'mako', {'input_encoding': 'utf-8'}),
83 ('templates/**.mako', 'mako', {'input_encoding': 'utf-8'}),
84 ('public/**', 'ignore', None)]},
84 ('public/**', 'ignore', None)]},
85 zip_safe=False,
85 zip_safe=False,
86 paster_plugins=['PasteScript', 'Pylons'],
86 paster_plugins=['PasteScript', 'Pylons'],
87 entry_points="""
87 entry_points="""
88 [paste.app_factory]
88 [paste.app_factory]
89 main = rhodecode.config.middleware:make_app
89 main = rhodecode.config.middleware:make_app
90
90
91 [paste.app_install]
91 [paste.app_install]
92 main = pylons.util:PylonsInstaller
92 main = pylons.util:PylonsInstaller
93
93
94 [paste.global_paster_command]
94 [paste.global_paster_command]
95 make-index = rhodecode.lib.indexers:MakeIndex
95 make-index = rhodecode.lib.indexers:MakeIndex
96 upgrade-db = rhodecode.lib.utils:UpgradeDb
96 upgrade-db = rhodecode.lib.dbmigrate:UpgradeDb
97 celeryd=rhodecode.lib.celerypylons.commands:CeleryDaemonCommand
97 celeryd=rhodecode.lib.celerypylons.commands:CeleryDaemonCommand
98 celerybeat=rhodecode.lib.celerypylons.commands:CeleryBeatCommand
98 celerybeat=rhodecode.lib.celerypylons.commands:CeleryBeatCommand
99 camqadm=rhodecode.lib.celerypylons.commands:CAMQPAdminCommand
99 camqadm=rhodecode.lib.celerypylons.commands:CAMQPAdminCommand
100 celeryev=rhodecode.lib.celerypylons.commands:CeleryEventCommand
100 celeryev=rhodecode.lib.celerypylons.commands:CeleryEventCommand
101
101
102 """,
102 """,
103 )
103 )
General Comments 0
You need to be logged in to leave comments. Login now