fixed imports on migrate, added getting current version from database
marcink
r835:08d2dcd7 beta
@@ -0,0 +1,20 b'rhodecode/lib/dbmigrate/migrate.cfg'
1 [db_settings]
2 # Used to identify which repository this database is versioned under.
3 # You can use the name of your project.
4 repository_id=rhodecode_db_migrations
5
6 # The name of the database table used to track the schema version.
7 # This name shouldn't already be used by your project.
8 # If this is changed once a database is under version control, you'll need to
9 # change the table name in each database too.
10 version_table=db_migrate_version
11
12 # When committing a change script, Migrate will attempt to generate the
13 # sql for all supported databases; normally, if one of them fails - probably
14 # because you don't have that database installed - it is ignored and the
15 # commit continues, perhaps ending successfully.
16 # Databases in this list MUST compile successfully during a commit, or the
17 # entire commit will fail. List the databases your application will actually
18 # be using to ensure your updates to that database work properly.
19 # This must be a list; example: ['postgres','sqlite']
20 required_dbs=['sqlite']
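For context, a minimal sketch (not part of this changeset) of how sqlalchemy-migrate consumes a repository configured by this file; the database URL and the repository location below are illustrative assumptions:

    from rhodecode.lib.dbmigrate.migrate.versioning import api

    REPOSITORY = 'rhodecode/lib/dbmigrate'   # directory holding this config and the versions/ scripts
    DB_URL = 'sqlite:///rhodecode.db'        # assumed database URL

    # put the database under version control; this creates the
    # db_migrate_version table named above and records repository_id in it
    api.version_control(DB_URL, REPOSITORY)

    # read back the schema version currently stored in that table
    print api.db_version(DB_URL, REPOSITORY)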
@@ -1,328 +1,332 b'rhodecode/lib/db_manage.py'
1 #!/usr/bin/env python
2 # encoding: utf-8
3 # database management for RhodeCode
4 # Copyright (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
5 #
1 # -*- coding: utf-8 -*-
2 """
3 rhodecode.lib.db_manage
4 ~~~~~~~~~~~~~~~~~~~~~~~
5
6 Database creation, and setup module for RhodeCode
7
8 :created_on: Apr 10, 2010
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
12 """
6 13 # This program is free software; you can redistribute it and/or
7 14 # modify it under the terms of the GNU General Public License
8 15 # as published by the Free Software Foundation; version 2
9 16 # of the License or (at your option) any later version of the license.
10 17 #
11 18 # This program is distributed in the hope that it will be useful,
12 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 21 # GNU General Public License for more details.
15 22 #
16 23 # You should have received a copy of the GNU General Public License
17 24 # along with this program; if not, write to the Free Software
18 25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 26 # MA 02110-1301, USA.
20 27
21 """
22 Created on April 10, 2010
23 database management and creation for RhodeCode
24 @author: marcink
25 """
26
27 from os.path import dirname as dn, join as jn
28 28 import os
29 29 import sys
30 30 import uuid
31 import logging
32 from os.path import dirname as dn, join as jn
33
34 from rhodecode import __dbversion__
35 from rhodecode.model.db import
36 from rhodecode.model import meta
31 37
32 38 from rhodecode.lib.auth import get_crypt_password
33 39 from rhodecode.lib.utils import ask_ok
34 40 from rhodecode.model import init_model
35 41 from rhodecode.model.db import User, Permission, RhodeCodeUi, RhodeCodeSettings, \
36 UserToPerm
37 from rhodecode.model import meta
42 UserToPerm, DbMigrateVersion
43
38 44 from sqlalchemy.engine import create_engine
39 import logging
45
40 46
41 47 log = logging.getLogger(__name__)
42 48
43 49 class DbManage(object):
44 50 def __init__(self, log_sql, dbconf, root, tests=False):
45 51 self.dbname = dbconf.split('/')[-1]
46 52 self.tests = tests
47 53 self.root = root
48 54 self.dburi = dbconf
49 55 engine = create_engine(self.dburi, echo=log_sql)
50 56 init_model(engine)
51 57 self.sa = meta.Session()
52 58 self.db_exists = False
53 59
54 60 def check_for_db(self, override):
55 61 db_path = jn(self.root, self.dbname)
56 62 if self.dburi.startswith('sqlite'):
57 63 log.info('checking for existing db in %s', db_path)
58 64 if os.path.isfile(db_path):
59 65
60 66 self.db_exists = True
61 67 if not override:
62 68 raise Exception('database already exists')
63 69
64 70 def create_tables(self, override=False):
65 71 """
66 72 Create an auth database
67 73 """
68 74 self.check_for_db(override)
69 75 if self.db_exists:
70 76 log.info("database exist and it's going to be destroyed")
71 77 if self.tests:
72 78 destroy = True
73 79 else:
74 80 destroy = ask_ok('Are you sure you want to destroy the old database? [y/n]')
75 81 if not destroy:
76 82 sys.exit()
77 83 if self.db_exists and destroy:
78 84 os.remove(jn(self.root, self.dbname))
79 85 checkfirst = not override
80 86 meta.Base.metadata.create_all(checkfirst=checkfirst)
81 87 log.info('Created tables for %s', self.dbname)
82 88
83 89
84 90
85 91 def set_db_version(self):
86 from rhodecode import __dbversion__
87 from rhodecode.model.db import DbMigrateVersion
88 92 try:
89 93 ver = DbMigrateVersion()
90 94 ver.version = __dbversion__
91 95 ver.repository_id = 'rhodecode_db_migrations'
92 96 ver.repository_path = 'versions'
93 97 self.sa.add(ver)
94 98 self.sa.commit()
95 99 except:
96 100 self.sa.rollback()
97 101 raise
98 102 log.info('db version set to: %s', __dbversion__)
99 103
100 104 def admin_prompt(self, second=False):
101 105 if not self.tests:
102 106 import getpass
103 107
104 108
105 109 def get_password():
106 110 password = getpass.getpass('Specify admin password (min 6 chars):')
107 111 confirm = getpass.getpass('Confirm password:')
108 112
109 113 if password != confirm:
110 114 log.error('passwords mismatch')
111 115 return False
112 116 if len(password) < 6:
113 117 log.error('password is too short, use at least 6 characters')
114 118 return False
115 119
116 120 return password
117 121
118 122 username = raw_input('Specify admin username:')
119 123
120 124 password = get_password()
121 125 if not password:
122 126 #second try
123 127 password = get_password()
124 128 if not password:
125 129 sys.exit()
126 130
127 131 email = raw_input('Specify admin email:')
128 132 self.create_user(username, password, email, True)
129 133 else:
130 134 log.info('creating admin and regular test users')
131 135 self.create_user('test_admin', 'test12', 'test_admin@mail.com', True)
132 136 self.create_user('test_regular', 'test12', 'test_regular@mail.com', False)
133 137 self.create_user('test_regular2', 'test12', 'test_regular2@mail.com', False)
134 138
135 139
136 140
137 141 def config_prompt(self, test_repo_path=''):
138 142 log.info('Setting up repositories config')
139 143
140 144 if not self.tests and not test_repo_path:
141 145 path = raw_input('Specify valid full path to your repositories'
142 146 ' you can change this later in application settings:')
143 147 else:
144 148 path = test_repo_path
145 149
146 150 if not os.path.isdir(path):
147 151 log.error('You entered wrong path: %s', path)
148 152 sys.exit()
149 153
150 154 hooks1 = RhodeCodeUi()
151 155 hooks1.ui_section = 'hooks'
152 156 hooks1.ui_key = 'changegroup.update'
153 157 hooks1.ui_value = 'hg update >&2'
154 158 hooks1.ui_active = False
155 159
156 160 hooks2 = RhodeCodeUi()
157 161 hooks2.ui_section = 'hooks'
158 162 hooks2.ui_key = 'changegroup.repo_size'
159 163 hooks2.ui_value = 'python:rhodecode.lib.hooks.repo_size'
160 164
161 165 hooks3 = RhodeCodeUi()
162 166 hooks3.ui_section = 'hooks'
163 167 hooks3.ui_key = 'pretxnchangegroup.push_logger'
164 168 hooks3.ui_value = 'python:rhodecode.lib.hooks.log_push_action'
165 169
166 170 hooks4 = RhodeCodeUi()
167 171 hooks4.ui_section = 'hooks'
168 172 hooks4.ui_key = 'preoutgoing.pull_logger'
169 173 hooks4.ui_value = 'python:rhodecode.lib.hooks.log_pull_action'
170 174
171 175 #for mercurial 1.7 set backward compatibility with format
172 176
173 177 dotencode_disable = RhodeCodeUi()
174 178 dotencode_disable.ui_section = 'format'
175 179 dotencode_disable.ui_key = 'dotencode'
176 180 dotencode_disable.ui_value = 'false'
177 181
178 182
179 183 web1 = RhodeCodeUi()
180 184 web1.ui_section = 'web'
181 185 web1.ui_key = 'push_ssl'
182 186 web1.ui_value = 'false'
183 187
184 188 web2 = RhodeCodeUi()
185 189 web2.ui_section = 'web'
186 190 web2.ui_key = 'allow_archive'
187 191 web2.ui_value = 'gz zip bz2'
188 192
189 193 web3 = RhodeCodeUi()
190 194 web3.ui_section = 'web'
191 195 web3.ui_key = 'allow_push'
192 196 web3.ui_value = '*'
193 197
194 198 web4 = RhodeCodeUi()
195 199 web4.ui_section = 'web'
196 200 web4.ui_key = 'baseurl'
197 201 web4.ui_value = '/'
198 202
199 203 paths = RhodeCodeUi()
200 204 paths.ui_section = 'paths'
201 205 paths.ui_key = '/'
202 206 paths.ui_value = path
203 207
204 208
205 209 hgsettings1 = RhodeCodeSettings('realm', 'RhodeCode authentication')
206 210 hgsettings2 = RhodeCodeSettings('title', 'RhodeCode')
207 211
208 212
209 213 try:
210 214 self.sa.add(hooks1)
211 215 self.sa.add(hooks2)
212 216 self.sa.add(hooks3)
213 217 self.sa.add(hooks4)
214 218 self.sa.add(web1)
215 219 self.sa.add(web2)
216 220 self.sa.add(web3)
217 221 self.sa.add(web4)
218 222 self.sa.add(paths)
219 223 self.sa.add(hgsettings1)
220 224 self.sa.add(hgsettings2)
221 225 self.sa.add(dotencode_disable)
222 226 for k in ['ldap_active', 'ldap_host', 'ldap_port', 'ldap_ldaps',
223 227 'ldap_dn_user', 'ldap_dn_pass', 'ldap_base_dn']:
224 228
225 229 setting = RhodeCodeSettings(k, '')
226 230 self.sa.add(setting)
227 231
228 232 self.sa.commit()
229 233 except:
230 234 self.sa.rollback()
231 235 raise
232 236 log.info('created ui config')
233 237
234 238 def create_user(self, username, password, email='', admin=False):
235 239 log.info('creating administrator user %s', username)
236 240 new_user = User()
237 241 new_user.username = username
238 242 new_user.password = get_crypt_password(password)
239 243 new_user.name = 'RhodeCode'
240 244 new_user.lastname = 'Admin'
241 245 new_user.email = email
242 246 new_user.admin = admin
243 247 new_user.active = True
244 248
245 249 try:
246 250 self.sa.add(new_user)
247 251 self.sa.commit()
248 252 except:
249 253 self.sa.rollback()
250 254 raise
251 255
252 256 def create_default_user(self):
253 257 log.info('creating default user')
254 258 #create default user for handling default permissions.
255 259 def_user = User()
256 260 def_user.username = 'default'
257 261 def_user.password = get_crypt_password(str(uuid.uuid1())[:8])
258 262 def_user.name = 'Anonymous'
259 263 def_user.lastname = 'User'
260 264 def_user.email = 'anonymous@rhodecode.org'
261 265 def_user.admin = False
262 266 def_user.active = False
263 267 try:
264 268 self.sa.add(def_user)
265 269 self.sa.commit()
266 270 except:
267 271 self.sa.rollback()
268 272 raise
269 273
270 274 def create_permissions(self):
271 275 #module.(access|create|change|delete)_[name]
272 276 #module.(read|write|owner)
273 277 perms = [('repository.none', 'Repository no access'),
274 278 ('repository.read', 'Repository read access'),
275 279 ('repository.write', 'Repository write access'),
276 280 ('repository.admin', 'Repository admin access'),
277 281 ('hg.admin', 'Hg Administrator'),
278 282 ('hg.create.repository', 'Repository create'),
279 283 ('hg.create.none', 'Repository creation disabled'),
280 284 ('hg.register.none', 'Register disabled'),
281 285 ('hg.register.manual_activate', 'Register new user with rhodecode with manual activation'),
282 286 ('hg.register.auto_activate', 'Register new user with rhodecode with auto activation'),
283 287 ]
284 288
285 289 for p in perms:
286 290 new_perm = Permission()
287 291 new_perm.permission_name = p[0]
288 292 new_perm.permission_longname = p[1]
289 293 try:
290 294 self.sa.add(new_perm)
291 295 self.sa.commit()
292 296 except:
293 297 self.sa.rollback()
294 298 raise
295 299
296 300 def populate_default_permissions(self):
297 301 log.info('creating default user permissions')
298 302
299 303 default_user = self.sa.query(User)\
300 304 .filter(User.username == 'default').scalar()
301 305
302 306 reg_perm = UserToPerm()
303 307 reg_perm.user = default_user
304 308 reg_perm.permission = self.sa.query(Permission)\
305 309 .filter(Permission.permission_name == 'hg.register.manual_activate')\
306 310 .scalar()
307 311
308 312 create_repo_perm = UserToPerm()
309 313 create_repo_perm.user = default_user
310 314 create_repo_perm.permission = self.sa.query(Permission)\
311 315 .filter(Permission.permission_name == 'hg.create.repository')\
312 316 .scalar()
313 317
314 318 default_repo_perm = UserToPerm()
315 319 default_repo_perm.user = default_user
316 320 default_repo_perm.permission = self.sa.query(Permission)\
317 321 .filter(Permission.permission_name == 'repository.read')\
318 322 .scalar()
319 323
320 324 try:
321 325 self.sa.add(reg_perm)
322 326 self.sa.add(create_repo_perm)
323 327 self.sa.add(default_repo_perm)
324 328 self.sa.commit()
325 329 except:
326 330 self.sa.rollback()
327 331 raise
328 332
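For orientation, the setup flow that drives this class looks roughly as follows; the URI, paths and exact call order are an illustrative sketch, not code from this changeset:

    from rhodecode.lib.db_manage import DbManage

    dbmanage = DbManage(log_sql=False,
                        dbconf='sqlite:////var/www/rhodecode/rhodecode.db',  # assumed URI
                        root='/var/www/rhodecode',                           # assumed app root
                        tests=False)
    dbmanage.create_tables(override=True)   # recreates the schema (prompts before destroying an existing db)
    dbmanage.set_db_version()               # stamps __dbversion__ into db_migrate_version
    dbmanage.config_prompt()                # asks for the repositories path, writes ui/settings rows
    dbmanage.create_default_user()
    dbmanage.admin_prompt()                 # asks for admin username/password/email
    dbmanage.create_permissions()
    dbmanage.populate_default_permissions()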
@@ -1,59 +1,77 b'rhodecode/lib/dbmigrate/__init__.py'
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 rhodecode.lib.dbmigrate.__init__
4 4 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5 5
6 6 Database migration modules
7 7
8 8 :created_on: Dec 11, 2010
9 9 :author: marcink
10 10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 11 :license: GPLv3, see COPYING for more details.
12 12 """
13 13 # This program is free software; you can redistribute it and/or
14 14 # modify it under the terms of the GNU General Public License
15 15 # as published by the Free Software Foundation; version 2
16 16 # of the License or (at your option) any later version of the license.
17 17 #
18 18 # This program is distributed in the hope that it will be useful,
19 19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 21 # GNU General Public License for more details.
22 22 #
23 23 # You should have received a copy of the GNU General Public License
24 24 # along with this program; if not, write to the Free Software
25 25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 26 # MA 02110-1301, USA.
27 27
28 from rhodecode.lib.utils import BasePasterCommand
28 import logging
29 from sqlalchemy import engine_from_config
30
31 from rhodecode.lib.dbmigrate.migrate.exceptions import \
32 DatabaseNotControlledError
29 33 from rhodecode.lib.utils import BasePasterCommand, Command, add_cache
30 34
31 from sqlalchemy import engine_from_config
35 log = logging.getLogger(__name__)
32 36
33 37 class UpgradeDb(BasePasterCommand):
34 38 """Command used for paster to upgrade our database to newer version
35 39 """
36 40
37 41 max_args = 1
38 42 min_args = 1
39 43
40 44 usage = "CONFIG_FILE"
41 45 summary = "Upgrades current db to newer version given configuration file"
42 46 group_name = "RhodeCode"
43 47
44 48 parser = Command.standard_parser(verbose=True)
45 49
46 50 def command(self):
47 51 from pylons import config
48 52 add_cache(config)
49 engine = engine_from_config(config, 'sqlalchemy.db1.')
50 print engine
51 raise NotImplementedError('Not implemented yet')
53 #engine = engine_from_config(config, 'sqlalchemy.db1.')
54 #print engine
55
56 from rhodecode.lib.dbmigrate.migrate.versioning import api
57 path = 'rhodecode/lib/dbmigrate'
58
52 59
60 try:
61 curr_version = api.db_version(config['sqlalchemy.db1.url'], path)
62 msg = ('Found current database under version'
63 ' control with version %s' % curr_version)
64
65 except (RuntimeError, DatabaseNotControlledError), e:
66 curr_version = 0
67 msg = ('Current database is not under version control, setting'
68 ' as version %s' % curr_version)
69
70 print msg
53 71
54 72 def update_parser(self):
55 73 self.parser.add_option('--sql',
56 74 action='store_true',
57 75 dest='just_sql',
58 76 help="Prints upgrade sql for further investigation",
59 77 default=False)
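So far the command only reports the current version; the eventual upgrade step would presumably go through the same vendored versioning API. A hedged sketch (the URL is an assumption, normally it would come from config['sqlalchemy.db1.url']; this code is not part of the changeset):

    from rhodecode.lib.dbmigrate.migrate.versioning import api

    DB_URL = 'sqlite:///rhodecode.db'        # assumed database URL
    REPOSITORY = 'rhodecode/lib/dbmigrate'

    # upgrade the schema to the latest version script in the repository
    api.upgrade(DB_URL, REPOSITORY)
    print api.db_version(DB_URL, REPOSITORY)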
@@ -1,9 +1,9 b'rhodecode/lib/dbmigrate/migrate/__init__.py'
1 1 """
2 2 SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
3 3 database schema version and repository management and
4 4 :mod:`migrate.changeset` that allows to define database schema changes
5 5 using Python.
6 6 """
7 7
8 from migrate.versioning import *
9 from migrate.changeset import *
8 from rhodecode.lib.dbmigrate.migrate.versioning import *
9 from rhodecode.lib.dbmigrate.migrate.changeset import *
@@ -1,28 +1,28 b'rhodecode/lib/dbmigrate/migrate/changeset/__init__.py'
1 1 """
2 2 This module extends SQLAlchemy and provides additional DDL [#]_
3 3 support.
4 4
5 5 .. [#] SQL Data Definition Language
6 6 """
7 7 import re
8 8 import warnings
9 9
10 10 import sqlalchemy
11 11 from sqlalchemy import __version__ as _sa_version
12 12
13 13 warnings.simplefilter('always', DeprecationWarning)
14 14
15 15 _sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
16 16 SQLA_06 = _sa_version >= (0, 6)
17 17
18 18 del re
19 19 del _sa_version
20 20
21 from migrate.changeset.schema import *
22 from migrate.changeset.constraint import *
21 from rhodecode.lib.dbmigrate.migrate.changeset.schema import *
22 from rhodecode.lib.dbmigrate.migrate.changeset.constraint import *
23 23
24 24 sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
25 25 sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
26 26 sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
27 27
28 28 sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )
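Because of the __bases__ patching above, importing this package makes plain SQLAlchemy Table/Column/Index objects grow the changeset methods. A small sketch (the engine URL and table definition are assumptions, not code from this changeset):

    import sqlalchemy as sa
    from rhodecode.lib.dbmigrate.migrate.changeset.schema import rename_table

    engine = sa.create_engine('sqlite://')
    meta = sa.MetaData(bind=engine)
    users = sa.Table('users', meta, sa.Column('id', sa.Integer, primary_key=True))
    meta.create_all()

    # ChangesetTable.rename is now available on the plain Table;
    # rename_table() is the functional wrapper around it
    rename_table(users, 'members')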
@@ -1,358 +1,358 b'rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py'
1 1 """
2 2 Extensions to SQLAlchemy for altering existing tables.
3 3
4 4 At the moment, this isn't so much based off of ANSI as much as
5 5 things that just happen to work with multiple databases.
6 6 """
7 7 import StringIO
8 8
9 9 import sqlalchemy as sa
10 10 from sqlalchemy.schema import SchemaVisitor
11 11 from sqlalchemy.engine.default import DefaultDialect
12 12 from sqlalchemy.sql import ClauseElement
13 13 from sqlalchemy.schema import (ForeignKeyConstraint,
14 14 PrimaryKeyConstraint,
15 15 CheckConstraint,
16 16 UniqueConstraint,
17 17 Index)
18 18
19 from migrate import exceptions
20 from migrate.changeset import constraint, SQLA_06
19 from rhodecode.lib.dbmigrate.migrate import exceptions
20 from rhodecode.lib.dbmigrate.migrate.changeset import constraint, SQLA_06
21 21
22 22 if not SQLA_06:
23 23 from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
24 24 else:
25 25 from sqlalchemy.schema import AddConstraint, DropConstraint
26 26 from sqlalchemy.sql.compiler import DDLCompiler
27 27 SchemaGenerator = SchemaDropper = DDLCompiler
28 28
29 29
30 30 class AlterTableVisitor(SchemaVisitor):
31 31 """Common operations for ``ALTER TABLE`` statements."""
32 32
33 33 if SQLA_06:
34 34 # engine.Compiler looks for .statement
35 35 # when it spawns off a new compiler
36 36 statement = ClauseElement()
37 37
38 38 def append(self, s):
39 39 """Append content to the SchemaIterator's query buffer."""
40 40
41 41 self.buffer.write(s)
42 42
43 43 def execute(self):
44 44 """Execute the contents of the SchemaIterator's buffer."""
45 45 try:
46 46 return self.connection.execute(self.buffer.getvalue())
47 47 finally:
48 48 self.buffer.truncate(0)
49 49
50 50 def __init__(self, dialect, connection, **kw):
51 51 self.connection = connection
52 52 self.buffer = StringIO.StringIO()
53 53 self.preparer = dialect.identifier_preparer
54 54 self.dialect = dialect
55 55
56 56 def traverse_single(self, elem):
57 57 ret = super(AlterTableVisitor, self).traverse_single(elem)
58 58 if ret:
59 59 # adapt to 0.6 which uses a string-returning
60 60 # object
61 61 self.append(" %s" % ret)
62 62
63 63 def _to_table(self, param):
64 64 """Returns the table object for the given param object."""
65 65 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
66 66 ret = param.table
67 67 else:
68 68 ret = param
69 69 return ret
70 70
71 71 def start_alter_table(self, param):
72 72 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
73 73
74 74 Use the param object to determine the table name and use it
75 75 for building the SQL statement.
76 76
77 77 :param param: object to determine the table from
78 78 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
79 79 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
80 80 or string (table name)
81 81 """
82 82 table = self._to_table(param)
83 83 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
84 84 return table
85 85
86 86
87 87 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
88 88 """Extends ansisql generator for column creation (alter table add col)"""
89 89
90 90 def visit_column(self, column):
91 91 """Create a column (table already exists).
92 92
93 93 :param column: column object
94 94 :type column: :class:`sqlalchemy.Column` instance
95 95 """
96 96 if column.default is not None:
97 97 self.traverse_single(column.default)
98 98
99 99 table = self.start_alter_table(column)
100 100 self.append("ADD ")
101 101 self.append(self.get_column_specification(column))
102 102
103 103 for cons in column.constraints:
104 104 self.traverse_single(cons)
105 105 self.execute()
106 106
107 107 # ALTER TABLE STATEMENTS
108 108
109 109 # add indexes and unique constraints
110 110 if column.index_name:
111 111 Index(column.index_name,column).create()
112 112 elif column.unique_name:
113 113 constraint.UniqueConstraint(column,
114 114 name=column.unique_name).create()
115 115
116 116 # SA bounds FK constraints to table, add manually
117 117 for fk in column.foreign_keys:
118 118 self.add_foreignkey(fk.constraint)
119 119
120 120 # add primary key constraint if needed
121 121 if column.primary_key_name:
122 122 cons = constraint.PrimaryKeyConstraint(column,
123 123 name=column.primary_key_name)
124 124 cons.create()
125 125
126 126 if SQLA_06:
127 127 def add_foreignkey(self, fk):
128 128 self.connection.execute(AddConstraint(fk))
129 129
130 130 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
131 131 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
132 132 DROP COLUMN``).
133 133 """
134 134
135 135 def visit_column(self, column):
136 136 """Drop a column from its table.
137 137
138 138 :param column: the column object
139 139 :type column: :class:`sqlalchemy.Column`
140 140 """
141 141 table = self.start_alter_table(column)
142 142 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
143 143 self.execute()
144 144
145 145
146 146 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
147 147 """Manages changes to existing schema elements.
148 148
149 149 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
150 150 is in SchemaGenerator.
151 151
152 152 All items may be renamed. Columns can also have many of their properties -
153 153 type, for example - changed.
154 154
155 155 Each function is passed a tuple, containing (object, name); where
156 156 object is a type of object you'd expect for that function
157 157 (ie. table for visit_table) and name is the object's new
158 158 name. NONE means the name is unchanged.
159 159 """
160 160
161 161 def visit_table(self, table):
162 162 """Rename a table. Other ops aren't supported."""
163 163 self.start_alter_table(table)
164 164 self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
165 165 table.quote))
166 166 self.execute()
167 167
168 168 def visit_index(self, index):
169 169 """Rename an index"""
170 170 if hasattr(self, '_validate_identifier'):
171 171 # SA <= 0.6.3
172 172 self.append("ALTER INDEX %s RENAME TO %s" % (
173 173 self.preparer.quote(
174 174 self._validate_identifier(
175 175 index.name, True), index.quote),
176 176 self.preparer.quote(
177 177 self._validate_identifier(
178 178 index.new_name, True), index.quote)))
179 179 else:
180 180 # SA >= 0.6.5
181 181 self.append("ALTER INDEX %s RENAME TO %s" % (
182 182 self.preparer.quote(
183 183 self._index_identifier(
184 184 index.name), index.quote),
185 185 self.preparer.quote(
186 186 self._index_identifier(
187 187 index.new_name), index.quote)))
188 188 self.execute()
189 189
190 190 def visit_column(self, delta):
191 191 """Rename/change a column."""
192 192 # ALTER COLUMN is implemented as several ALTER statements
193 193 keys = delta.keys()
194 194 if 'type' in keys:
195 195 self._run_subvisit(delta, self._visit_column_type)
196 196 if 'nullable' in keys:
197 197 self._run_subvisit(delta, self._visit_column_nullable)
198 198 if 'server_default' in keys:
199 199 # Skip 'default': only handle server-side defaults, others
200 200 # are managed by the app, not the db.
201 201 self._run_subvisit(delta, self._visit_column_default)
202 202 if 'name' in keys:
203 203 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
204 204
205 205 def _run_subvisit(self, delta, func, start_alter=True):
206 206 """Runs visit method based on what needs to be changed on column"""
207 207 table = self._to_table(delta.table)
208 208 col_name = delta.current_name
209 209 if start_alter:
210 210 self.start_alter_column(table, col_name)
211 211 ret = func(table, delta.result_column, delta)
212 212 self.execute()
213 213
214 214 def start_alter_column(self, table, col_name):
215 215 """Starts ALTER COLUMN"""
216 216 self.start_alter_table(table)
217 217 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
218 218
219 219 def _visit_column_nullable(self, table, column, delta):
220 220 nullable = delta['nullable']
221 221 if nullable:
222 222 self.append("DROP NOT NULL")
223 223 else:
224 224 self.append("SET NOT NULL")
225 225
226 226 def _visit_column_default(self, table, column, delta):
227 227 default_text = self.get_column_default_string(column)
228 228 if default_text is not None:
229 229 self.append("SET DEFAULT %s" % default_text)
230 230 else:
231 231 self.append("DROP DEFAULT")
232 232
233 233 def _visit_column_type(self, table, column, delta):
234 234 type_ = delta['type']
235 235 if SQLA_06:
236 236 type_text = str(type_.compile(dialect=self.dialect))
237 237 else:
238 238 type_text = type_.dialect_impl(self.dialect).get_col_spec()
239 239 self.append("TYPE %s" % type_text)
240 240
241 241 def _visit_column_name(self, table, column, delta):
242 242 self.start_alter_table(table)
243 243 col_name = self.preparer.quote(delta.current_name, table.quote)
244 244 new_name = self.preparer.format_column(delta.result_column)
245 245 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
246 246
247 247
248 248 class ANSIConstraintCommon(AlterTableVisitor):
249 249 """
250 250 Migrate's constraints require a separate creation function from
251 251 SA's: Migrate's constraints are created independently of a table;
252 252 SA's are created at the same time as the table.
253 253 """
254 254
255 255 def get_constraint_name(self, cons):
256 256 """Gets a name for the given constraint.
257 257
258 258 If the name is already set it will be used otherwise the
259 259 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
260 260 method is used.
261 261
262 262 :param cons: constraint object
263 263 """
264 264 if cons.name is not None:
265 265 ret = cons.name
266 266 else:
267 267 ret = cons.name = cons.autoname()
268 268 return self.preparer.quote(ret, cons.quote)
269 269
270 270 def visit_migrate_primary_key_constraint(self, *p, **k):
271 271 self._visit_constraint(*p, **k)
272 272
273 273 def visit_migrate_foreign_key_constraint(self, *p, **k):
274 274 self._visit_constraint(*p, **k)
275 275
276 276 def visit_migrate_check_constraint(self, *p, **k):
277 277 self._visit_constraint(*p, **k)
278 278
279 279 def visit_migrate_unique_constraint(self, *p, **k):
280 280 self._visit_constraint(*p, **k)
281 281
282 282 if SQLA_06:
283 283 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
284 284 def _visit_constraint(self, constraint):
285 285 constraint.name = self.get_constraint_name(constraint)
286 286 self.append(self.process(AddConstraint(constraint)))
287 287 self.execute()
288 288
289 289 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
290 290 def _visit_constraint(self, constraint):
291 291 constraint.name = self.get_constraint_name(constraint)
292 292 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
293 293 self.execute()
294 294
295 295 else:
296 296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
297 297
298 298 def get_constraint_specification(self, cons, **kwargs):
299 299 """Constaint SQL generators.
300 300
301 301 We cannot use SA visitors because they append comma.
302 302 """
303 303
304 304 if isinstance(cons, PrimaryKeyConstraint):
305 305 if cons.name is not None:
306 306 self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
307 307 self.append("PRIMARY KEY ")
308 308 self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
309 309 for c in cons))
310 310 self.define_constraint_deferrability(cons)
311 311 elif isinstance(cons, ForeignKeyConstraint):
312 312 self.define_foreign_key(cons)
313 313 elif isinstance(cons, CheckConstraint):
314 314 if cons.name is not None:
315 315 self.append("CONSTRAINT %s " %
316 316 self.preparer.format_constraint(cons))
317 317 self.append("CHECK (%s)" % cons.sqltext)
318 318 self.define_constraint_deferrability(cons)
319 319 elif isinstance(cons, UniqueConstraint):
320 320 if cons.name is not None:
321 321 self.append("CONSTRAINT %s " %
322 322 self.preparer.format_constraint(cons))
323 323 self.append("UNIQUE (%s)" % \
324 324 (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
325 325 self.define_constraint_deferrability(cons)
326 326 else:
327 327 raise exceptions.InvalidConstraintError(cons)
328 328
329 329 def _visit_constraint(self, constraint):
330 330
331 331 table = self.start_alter_table(constraint)
332 332 constraint.name = self.get_constraint_name(constraint)
333 333 self.append("ADD ")
334 334 self.get_constraint_specification(constraint)
335 335 self.execute()
336 336
337 337
338 338 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
339 339
340 340 def _visit_constraint(self, constraint):
341 341 self.start_alter_table(constraint)
342 342 self.append("DROP CONSTRAINT ")
343 343 constraint.name = self.get_constraint_name(constraint)
344 344 self.append(self.preparer.format_constraint(constraint))
345 345 if constraint.cascade:
346 346 self.cascade_constraint(constraint)
347 347 self.execute()
348 348
349 349 def cascade_constraint(self, constraint):
350 350 self.append(" CASCADE")
351 351
352 352
353 353 class ANSIDialect(DefaultDialect):
354 354 columngenerator = ANSIColumnGenerator
355 355 columndropper = ANSIColumnDropper
356 356 schemachanger = ANSISchemaChanger
357 357 constraintgenerator = ANSIConstraintGenerator
358 358 constraintdropper = ANSIConstraintDropper
@@ -1,202 +1,202 b'rhodecode/lib/dbmigrate/migrate/changeset/constraint.py'
1 1 """
2 2 This module defines standalone schema constraint classes.
3 3 """
4 4 from sqlalchemy import schema
5 5
6 from migrate.exceptions import *
7 from migrate.changeset import SQLA_06
6 from rhodecode.lib.dbmigrate.migrate.exceptions import *
7 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
8 8
9 9 class ConstraintChangeset(object):
10 10 """Base class for Constraint classes."""
11 11
12 12 def _normalize_columns(self, cols, table_name=False):
13 13 """Given: column objects or names; return col names and
14 14 (maybe) a table"""
15 15 colnames = []
16 16 table = None
17 17 for col in cols:
18 18 if isinstance(col, schema.Column):
19 19 if col.table is not None and table is None:
20 20 table = col.table
21 21 if table_name:
22 22 col = '.'.join((col.table.name, col.name))
23 23 else:
24 24 col = col.name
25 25 colnames.append(col)
26 26 return colnames, table
27 27
28 28 def __do_imports(self, visitor_name, *a, **kw):
29 29 engine = kw.pop('engine', self.table.bind)
30 from migrate.changeset.databases.visitor import (get_engine_visitor,
30 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
31 31 run_single_visitor)
32 32 visitorcallable = get_engine_visitor(engine, visitor_name)
33 33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
34 34
35 35 def create(self, *a, **kw):
36 36 """Create the constraint in the database.
37 37
38 38 :param engine: the database engine to use. If this is \
39 39 :keyword:`None` the instance's engine will be used
40 40 :type engine: :class:`sqlalchemy.engine.base.Engine`
41 41 :param connection: reuse connection instead of creating a new one.
42 42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
43 43 """
44 44 # TODO: set the parent here instead of in __init__
45 45 self.__do_imports('constraintgenerator', *a, **kw)
46 46
47 47 def drop(self, *a, **kw):
48 48 """Drop the constraint from the database.
49 49
50 50 :param engine: the database engine to use. If this is
51 51 :keyword:`None` the instance's engine will be used
52 52 :param cascade: Issue CASCADE drop if database supports it
53 53 :type engine: :class:`sqlalchemy.engine.base.Engine`
54 54 :type cascade: bool
55 55 :param connection: reuse connection instead of creating a new one.
56 56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
57 57 :returns: Instance with cleared columns
58 58 """
59 59 self.cascade = kw.pop('cascade', False)
60 60 self.__do_imports('constraintdropper', *a, **kw)
61 61 # the spirit of Constraint objects is that they
62 62 # are immutable (just like in a DB. they're only ADDed
63 63 # or DROPped).
64 64 #self.columns.clear()
65 65 return self
66 66
67 67
68 68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
69 69 """Construct PrimaryKeyConstraint
70 70
71 71 Migrate's additional parameters:
72 72
73 73 :param cols: Columns in constraint.
74 74 :param table: If columns are passed as strings, this kw is required
75 75 :type table: Table instance
76 76 :type cols: strings or Column instances
77 77 """
78 78
79 79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
80 80
81 81 def __init__(self, *cols, **kwargs):
82 82 colnames, table = self._normalize_columns(cols)
83 83 table = kwargs.pop('table', table)
84 84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
85 85 if table is not None:
86 86 self._set_parent(table)
87 87
88 88
89 89 def autoname(self):
90 90 """Mimic the database's automatic constraint names"""
91 91 return "%s_pkey" % self.table.name
92 92
93 93
94 94 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
95 95 """Construct ForeignKeyConstraint
96 96
97 97 Migrate's additional parameters:
98 98
99 99 :param columns: Columns in constraint
100 100 :param refcolumns: Columns that this FK refers to in another table.
101 101 :param table: If columns are passed as strings, this kw is required
102 102 :type table: Table instance
103 103 :type columns: list of strings or Column instances
104 104 :type refcolumns: list of strings or Column instances
105 105 """
106 106
107 107 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
108 108
109 109 def __init__(self, columns, refcolumns, *args, **kwargs):
110 110 colnames, table = self._normalize_columns(columns)
111 111 table = kwargs.pop('table', table)
112 112 refcolnames, reftable = self._normalize_columns(refcolumns,
113 113 table_name=True)
114 114 super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
115 115 **kwargs)
116 116 if table is not None:
117 117 self._set_parent(table)
118 118
119 119 @property
120 120 def referenced(self):
121 121 return [e.column for e in self.elements]
122 122
123 123 @property
124 124 def reftable(self):
125 125 return self.referenced[0].table
126 126
127 127 def autoname(self):
128 128 """Mimic the database's automatic constraint names"""
129 129 if hasattr(self.columns, 'keys'):
130 130 # SA <= 0.5
131 131 firstcol = self.columns[self.columns.keys()[0]]
132 132 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
133 133 table=firstcol.table.name,
134 134 firstcolumn=firstcol.name,)
135 135 else:
136 136 # SA >= 0.6
137 137 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
138 138 table=self.table.name,
139 139 firstcolumn=self.columns[0],)
140 140 return ret
141 141
142 142
143 143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
144 144 """Construct CheckConstraint
145 145
146 146 Migrate's additional parameters:
147 147
148 148 :param sqltext: Plain SQL text to check condition
149 149 :param columns: If no name is applied, you must supply this kw\
150 150 to autoname constraint
151 151 :param table: If columns are passed as strings, this kw is required
152 152 :type table: Table instance
153 153 :type columns: list of Columns instances
154 154 :type sqltext: string
155 155 """
156 156
157 157 __migrate_visit_name__ = 'migrate_check_constraint'
158 158
159 159 def __init__(self, sqltext, *args, **kwargs):
160 160 cols = kwargs.pop('columns', [])
161 161 if not cols and not kwargs.get('name', False):
162 162 raise InvalidConstraintError('You must either set "name"'
163 163 ' parameter or "columns" to autogenerate it.')
164 164 colnames, table = self._normalize_columns(cols)
165 165 table = kwargs.pop('table', table)
166 166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
167 167 if table is not None:
168 168 if not SQLA_06:
169 169 self.table = table
170 170 self._set_parent(table)
171 171 self.colnames = colnames
172 172
173 173 def autoname(self):
174 174 return "%(table)s_%(cols)s_check" % \
175 175 dict(table=self.table.name, cols="_".join(self.colnames))
176 176
177 177
178 178 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
179 179 """Construct UniqueConstraint
180 180
181 181 Migrate's additional parameters:
182 182
183 183 :param cols: Columns in constraint.
184 184 :param table: If columns are passed as strings, this kw is required
185 185 :type table: Table instance
186 186 :type cols: strings or Column instances
187 187
188 188 .. versionadded:: 0.6.0
189 189 """
190 190
191 191 __migrate_visit_name__ = 'migrate_unique_constraint'
192 192
193 193 def __init__(self, *cols, **kwargs):
194 194 self.colnames, table = self._normalize_columns(cols)
195 195 table = kwargs.pop('table', table)
196 196 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
197 197 if table is not None:
198 198 self._set_parent(table)
199 199
200 200 def autoname(self):
201 201 """Mimic the database's automatic constraint names"""
202 202 return "%s_%s_key" % (self.table.name, self.colnames[0])
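A short sketch of the standalone constraint workflow these classes provide; the engine URL and table definition are assumptions, not code from this module:

    import sqlalchemy as sa
    from rhodecode.lib.dbmigrate.migrate.changeset.constraint import UniqueConstraint

    engine = sa.create_engine('sqlite://')
    meta = sa.MetaData(bind=engine)
    users = sa.Table('users', meta,
                     sa.Column('user_id', sa.Integer, primary_key=True),
                     sa.Column('username', sa.String(40)))
    meta.create_all()

    # the constraint is created independently of the table DDL;
    # autoname() would yield 'users_username_key' if no explicit name is given,
    # and on SQLite this goes through the table-recreation path in databases/sqlite.py
    cons = UniqueConstraint('username', table=users)
    cons.create()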
@@ -1,80 +1,80 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/firebird.py'
1 1 """
2 2 Firebird database specific implementations of changeset classes.
3 3 """
4 4 from sqlalchemy.databases import firebird as sa_base
5 5
6 from migrate import exceptions
7 from migrate.changeset import ansisql, SQLA_06
6 from rhodecode.lib.dbmigrate.migrate import exceptions
7 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
8 8
9 9
10 10 if SQLA_06:
11 11 FBSchemaGenerator = sa_base.FBDDLCompiler
12 12 else:
13 13 FBSchemaGenerator = sa_base.FBSchemaGenerator
14 14
15 15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
16 16 """Firebird column generator implementation."""
17 17
18 18
19 19 class FBColumnDropper(ansisql.ANSIColumnDropper):
20 20 """Firebird column dropper implementation."""
21 21
22 22 def visit_column(self, column):
23 23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
24 24
25 25 Drop primary key and unique constraints if dropped column is referencing it."""
26 26 if column.primary_key:
27 27 if column.table.primary_key.columns.contains_column(column):
28 28 column.table.primary_key.drop()
29 29 # TODO: recreate primary key if it references more than this column
30 30 if column.unique or getattr(column, 'unique_name', None):
31 31 for cons in column.table.constraints:
32 32 if cons.contains_column(column):
33 33 cons.drop()
34 34 # TODO: recreate unique constraint if it references more than this column
35 35
36 36 table = self.start_alter_table(column)
37 37 self.append('DROP %s' % self.preparer.format_column(column))
38 38 self.execute()
39 39
40 40
41 41 class FBSchemaChanger(ansisql.ANSISchemaChanger):
42 42 """Firebird schema changer implementation."""
43 43
44 44 def visit_table(self, table):
45 45 """Rename table not supported"""
46 46 raise exceptions.NotSupportedError(
47 47 "Firebird does not support renaming tables.")
48 48
49 49 def _visit_column_name(self, table, column, delta):
50 50 self.start_alter_table(table)
51 51 col_name = self.preparer.quote(delta.current_name, table.quote)
52 52 new_name = self.preparer.format_column(delta.result_column)
53 53 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
54 54
55 55 def _visit_column_nullable(self, table, column, delta):
56 56 """Changing NULL is not supported"""
57 57 # TODO: http://www.firebirdfaq.org/faq103/
58 58 raise exceptions.NotSupportedError(
59 59 "Firebird does not support altering NULL bevahior.")
60 60
61 61
62 62 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
63 63 """Firebird constraint generator implementation."""
64 64
65 65
66 66 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
67 67 """Firebird constaint dropper implementation."""
68 68
69 69 def cascade_constraint(self, constraint):
70 70 """Cascading constraints is not supported"""
71 71 raise exceptions.NotSupportedError(
72 72 "Firebird does not support cascading constraints")
73 73
74 74
75 75 class FBDialect(ansisql.ANSIDialect):
76 76 columngenerator = FBColumnGenerator
77 77 columndropper = FBColumnDropper
78 78 schemachanger = FBSchemaChanger
79 79 constraintgenerator = FBConstraintGenerator
80 80 constraintdropper = FBConstraintDropper
@@ -1,94 +1,94 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/mysql.py'
1 1 """
2 2 MySQL database specific implementations of changeset classes.
3 3 """
4 4
5 5 from sqlalchemy.databases import mysql as sa_base
6 6 from sqlalchemy import types as sqltypes
7 7
8 from migrate import exceptions
9 from migrate.changeset import ansisql, SQLA_06
8 from rhodecode.lib.dbmigrate.migrate import exceptions
9 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
10 10
11 11
12 12 if not SQLA_06:
13 13 MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
14 14 else:
15 15 MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
16 16
17 17 class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
18 18 pass
19 19
20 20
21 21 class MySQLColumnDropper(ansisql.ANSIColumnDropper):
22 22 pass
23 23
24 24
25 25 class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
26 26
27 27 def visit_column(self, delta):
28 28 table = delta.table
29 29 colspec = self.get_column_specification(delta.result_column)
30 30 if delta.result_column.autoincrement:
31 31 primary_keys = [c for c in table.primary_key.columns
32 32 if (c.autoincrement and
33 33 isinstance(c.type, sqltypes.Integer) and
34 34 not c.foreign_keys)]
35 35
36 36 if primary_keys:
37 37 first = primary_keys.pop(0)
38 38 if first.name == delta.current_name:
39 39 colspec += " AUTO_INCREMENT"
40 40 old_col_name = self.preparer.quote(delta.current_name, table.quote)
41 41
42 42 self.start_alter_table(table)
43 43
44 44 self.append("CHANGE COLUMN %s " % old_col_name)
45 45 self.append(colspec)
46 46 self.execute()
47 47
48 48 def visit_index(self, param):
49 49 # If MySQL can do this, I can't find how
50 50 raise exceptions.NotSupportedError("MySQL cannot rename indexes")
51 51
52 52
53 53 class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
54 54 pass
55 55
56 56 if SQLA_06:
57 57 class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
58 58 def visit_migrate_check_constraint(self, *p, **k):
59 59 raise exceptions.NotSupportedError("MySQL does not support CHECK"
60 60 " constraints, use triggers instead.")
61 61
62 62 else:
63 63 class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
64 64
65 65 def visit_migrate_primary_key_constraint(self, constraint):
66 66 self.start_alter_table(constraint)
67 67 self.append("DROP PRIMARY KEY")
68 68 self.execute()
69 69
70 70 def visit_migrate_foreign_key_constraint(self, constraint):
71 71 self.start_alter_table(constraint)
72 72 self.append("DROP FOREIGN KEY ")
73 73 constraint.name = self.get_constraint_name(constraint)
74 74 self.append(self.preparer.format_constraint(constraint))
75 75 self.execute()
76 76
77 77 def visit_migrate_check_constraint(self, *p, **k):
78 78 raise exceptions.NotSupportedError("MySQL does not support CHECK"
79 79 " constraints, use triggers instead.")
80 80
81 81 def visit_migrate_unique_constraint(self, constraint, *p, **k):
82 82 self.start_alter_table(constraint)
83 83 self.append('DROP INDEX ')
84 84 constraint.name = self.get_constraint_name(constraint)
85 85 self.append(self.preparer.format_constraint(constraint))
86 86 self.execute()
87 87
88 88
89 89 class MySQLDialect(ansisql.ANSIDialect):
90 90 columngenerator = MySQLColumnGenerator
91 91 columndropper = MySQLColumnDropper
92 92 schemachanger = MySQLSchemaChanger
93 93 constraintgenerator = MySQLConstraintGenerator
94 94 constraintdropper = MySQLConstraintDropper
@@ -1,111 +1,111 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/oracle.py'
1 1 """
2 2 Oracle database specific implementations of changeset classes.
3 3 """
4 4 import sqlalchemy as sa
5 5 from sqlalchemy.databases import oracle as sa_base
6 6
7 from migrate import exceptions
8 from migrate.changeset import ansisql, SQLA_06
7 from rhodecode.lib.dbmigrate.migrate import exceptions
8 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
9 9
10 10
11 11 if not SQLA_06:
12 12 OracleSchemaGenerator = sa_base.OracleSchemaGenerator
13 13 else:
14 14 OracleSchemaGenerator = sa_base.OracleDDLCompiler
15 15
16 16
17 17 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
18 18 pass
19 19
20 20
21 21 class OracleColumnDropper(ansisql.ANSIColumnDropper):
22 22 pass
23 23
24 24
25 25 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
26 26
27 27 def get_column_specification(self, column, **kwargs):
28 28 # Ignore the NOT NULL generated
29 29 override_nullable = kwargs.pop('override_nullable', None)
30 30 if override_nullable:
31 31 orig = column.nullable
32 32 column.nullable = True
33 33 ret = super(OracleSchemaChanger, self).get_column_specification(
34 34 column, **kwargs)
35 35 if override_nullable:
36 36 column.nullable = orig
37 37 return ret
38 38
39 39 def visit_column(self, delta):
40 40 keys = delta.keys()
41 41
42 42 if 'name' in keys:
43 43 self._run_subvisit(delta,
44 44 self._visit_column_name,
45 45 start_alter=False)
46 46
47 47 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
48 48 self._run_subvisit(delta,
49 49 self._visit_column_change,
50 50 start_alter=False)
51 51
52 52 def _visit_column_change(self, table, column, delta):
53 53 # Oracle cannot drop a default once created, but it can set it
54 54 # to null. We'll do that if default=None
55 55 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
56 56 dropdefault_hack = (column.server_default is None \
57 57 and 'server_default' in delta.keys())
58 58 # Oracle apparently doesn't like it when we say "not null" if
59 59 # the column's already not null. Fudge it, so we don't need a
60 60 # new function
61 61 notnull_hack = ((not column.nullable) \
62 62 and ('nullable' not in delta.keys()))
63 63 # We need to specify NULL if we're removing a NOT NULL
64 64 # constraint
65 65 null_hack = (column.nullable and ('nullable' in delta.keys()))
66 66
67 67 if dropdefault_hack:
68 68 column.server_default = sa.PassiveDefault(sa.sql.null())
69 69 if notnull_hack:
70 70 column.nullable = True
71 71 colspec = self.get_column_specification(column,
72 72 override_nullable=null_hack)
73 73 if null_hack:
74 74 colspec += ' NULL'
75 75 if notnull_hack:
76 76 column.nullable = False
77 77 if dropdefault_hack:
78 78 column.server_default = None
79 79
80 80 self.start_alter_table(table)
81 81 self.append("MODIFY (")
82 82 self.append(colspec)
83 83 self.append(")")
84 84
85 85
86 86 class OracleConstraintCommon(object):
87 87
88 88 def get_constraint_name(self, cons):
89 89 # Oracle constraints can't guess their name like other DBs
90 90 if not cons.name:
91 91 raise exceptions.NotSupportedError(
92 92 "Oracle constraint names must be explicitly stated")
93 93 return cons.name
94 94
95 95
96 96 class OracleConstraintGenerator(OracleConstraintCommon,
97 97 ansisql.ANSIConstraintGenerator):
98 98 pass
99 99
100 100
101 101 class OracleConstraintDropper(OracleConstraintCommon,
102 102 ansisql.ANSIConstraintDropper):
103 103 pass
104 104
105 105
106 106 class OracleDialect(ansisql.ANSIDialect):
107 107 columngenerator = OracleColumnGenerator
108 108 columndropper = OracleColumnDropper
109 109 schemachanger = OracleSchemaChanger
110 110 constraintgenerator = OracleConstraintGenerator
111 111 constraintdropper = OracleConstraintDropper
@@ -1,46 +1,46 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/postgres.py'
1 1 """
2 2 `PostgreSQL`_ database specific implementations of changeset classes.
3 3
4 4 .. _`PostgreSQL`: http://www.postgresql.org/
5 5 """
6 from migrate.changeset import ansisql, SQLA_06
6 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
7 7
8 8 if not SQLA_06:
9 9 from sqlalchemy.databases import postgres as sa_base
10 10 PGSchemaGenerator = sa_base.PGSchemaGenerator
11 11 else:
12 12 from sqlalchemy.databases import postgresql as sa_base
13 13 PGSchemaGenerator = sa_base.PGDDLCompiler
14 14
15 15
16 16 class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
17 17 """PostgreSQL column generator implementation."""
18 18 pass
19 19
20 20
21 21 class PGColumnDropper(ansisql.ANSIColumnDropper):
22 22 """PostgreSQL column dropper implementation."""
23 23 pass
24 24
25 25
26 26 class PGSchemaChanger(ansisql.ANSISchemaChanger):
27 27 """PostgreSQL schema changer implementation."""
28 28 pass
29 29
30 30
31 31 class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
32 32 """PostgreSQL constraint generator implementation."""
33 33 pass
34 34
35 35
36 36 class PGConstraintDropper(ansisql.ANSIConstraintDropper):
37 37 """PostgreSQL constaint dropper implementation."""
38 38 pass
39 39
40 40
41 41 class PGDialect(ansisql.ANSIDialect):
42 42 columngenerator = PGColumnGenerator
43 43 columndropper = PGColumnDropper
44 44 schemachanger = PGSchemaChanger
45 45 constraintgenerator = PGConstraintGenerator
46 46 constraintdropper = PGConstraintDropper
@@ -1,148 +1,148 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/sqlite.py'
1 1 """
2 2 `SQLite`_ database specific implementations of changeset classes.
3 3
4 4 .. _`SQLite`: http://www.sqlite.org/
5 5 """
6 6 from UserDict import DictMixin
7 7 from copy import copy
8 8
9 9 from sqlalchemy.databases import sqlite as sa_base
10 10
11 from migrate import exceptions
12 from migrate.changeset import ansisql, SQLA_06
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
13 13
14 14
15 15 if not SQLA_06:
16 16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
17 17 else:
18 18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
19 19
20 20 class SQLiteCommon(object):
21 21
22 22 def _not_supported(self, op):
23 23 raise exceptions.NotSupportedError("SQLite does not support "
24 24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
25 25
26 26
27 27 class SQLiteHelper(SQLiteCommon):
28 28
29 29 def recreate_table(self,table,column=None,delta=None):
30 30 table_name = self.preparer.format_table(table)
31 31
32 32 # we remove all indexes so as not to have
33 33 # problems during copy and re-create
34 34 for index in table.indexes:
35 35 index.drop()
36 36
37 37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
38 38 self.execute()
39 39
40 40 insertion_string = self._modify_table(table, column, delta)
41 41
42 42 table.create()
43 43 self.append(insertion_string % {'table_name': table_name})
44 44 self.execute()
45 45 self.append('DROP TABLE migration_tmp')
46 46 self.execute()
47 47
48 48 def visit_column(self, delta):
49 49 if isinstance(delta, DictMixin):
50 50 column = delta.result_column
51 51 table = self._to_table(delta.table)
52 52 else:
53 53 column = delta
54 54 table = self._to_table(column.table)
55 55 self.recreate_table(table,column,delta)
56 56
57 57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
58 58 ansisql.ANSIColumnGenerator,
59 59 # at the end so we get the normal
60 60 # visit_column by default
61 61 SQLiteHelper,
62 62 SQLiteCommon
63 63 ):
64 64 """SQLite ColumnGenerator"""
65 65
66 66 def _modify_table(self, table, column, delta):
67 67 columns = ' ,'.join(map(
68 68 self.preparer.format_column,
69 69 [c for c in table.columns if c.name!=column.name]))
70 70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
71 71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
72 72
73 73 def visit_column(self,column):
74 74 if column.foreign_keys:
75 75 SQLiteHelper.visit_column(self,column)
76 76 else:
77 77 super(SQLiteColumnGenerator,self).visit_column(column)
78 78
79 79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
80 80 """SQLite ColumnDropper"""
81 81
82 82 def _modify_table(self, table, column, delta):
83 83 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
84 84 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
85 85 ' from migration_tmp'
86 86
87 87
88 88 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
89 89 """SQLite SchemaChanger"""
90 90
91 91 def _modify_table(self, table, column, delta):
92 92 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
93 93
94 94 def visit_index(self, index):
95 95 """Does not support ALTER INDEX"""
96 96 self._not_supported('ALTER INDEX')
97 97
98 98
99 99 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
100 100
101 101 def visit_migrate_primary_key_constraint(self, constraint):
102 102 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
103 103 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
104 104 tname = self.preparer.format_table(constraint.table)
105 105 name = self.get_constraint_name(constraint)
106 106 msg = tmpl % (name, tname, cols)
107 107 self.append(msg)
108 108 self.execute()
109 109
110 110 def _modify_table(self, table, column, delta):
111 111 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
112 112
113 113 def visit_migrate_foreign_key_constraint(self, *p, **k):
114 114 self.recreate_table(p[0].table)
115 115
116 116 def visit_migrate_unique_constraint(self, *p, **k):
117 117 self.recreate_table(p[0].table)
118 118
119 119
120 120 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
121 121 SQLiteCommon,
122 122 ansisql.ANSIConstraintCommon):
123 123
124 124 def visit_migrate_primary_key_constraint(self, constraint):
125 125 tmpl = "DROP INDEX %s "
126 126 name = self.get_constraint_name(constraint)
127 127 msg = tmpl % (name)
128 128 self.append(msg)
129 129 self.execute()
130 130
131 131 def visit_migrate_foreign_key_constraint(self, *p, **k):
132 132 self._not_supported('ALTER TABLE DROP CONSTRAINT')
133 133
134 134 def visit_migrate_check_constraint(self, *p, **k):
135 135 self._not_supported('ALTER TABLE DROP CONSTRAINT')
136 136
137 137 def visit_migrate_unique_constraint(self, *p, **k):
138 138 self._not_supported('ALTER TABLE DROP CONSTRAINT')
139 139
140 140
141 141 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
142 142
143 143 class SQLiteDialect(ansisql.ANSIDialect):
144 144 columngenerator = SQLiteColumnGenerator
145 145 columndropper = SQLiteColumnDropper
146 146 schemachanger = SQLiteSchemaChanger
147 147 constraintgenerator = SQLiteConstraintGenerator
148 148 constraintdropper = SQLiteConstraintDropper
@@ -1,78 +1,78 b'rhodecode/lib/dbmigrate/migrate/changeset/databases/visitor.py'
1 1 """
2 2 Module for visitor class mapping.
3 3 """
4 4 import sqlalchemy as sa
5 5
6 from migrate.changeset import ansisql
7 from migrate.changeset.databases import (sqlite,
6 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql
7 from rhodecode.lib.dbmigrate.migrate.changeset.databases import (sqlite,
8 8 postgres,
9 9 mysql,
10 10 oracle,
11 11 firebird)
12 12
13 13
14 14 # Map SA dialects to the corresponding Migrate extensions
15 15 DIALECTS = {
16 16 "default": ansisql.ANSIDialect,
17 17 "sqlite": sqlite.SQLiteDialect,
18 18 "postgres": postgres.PGDialect,
19 19 "postgresql": postgres.PGDialect,
20 20 "mysql": mysql.MySQLDialect,
21 21 "oracle": oracle.OracleDialect,
22 22 "firebird": firebird.FBDialect,
23 23 }
24 24
25 25
26 26 def get_engine_visitor(engine, name):
27 27 """
28 28 Get the visitor implementation for the given database engine.
29 29
30 30 :param engine: SQLAlchemy Engine
31 31 :param name: Name of the visitor
32 32 :type name: string
33 33 :type engine: Engine
34 34 :returns: visitor
35 35 """
36 36 # TODO: link to supported visitors
37 37 return get_dialect_visitor(engine.dialect, name)
38 38
39 39
40 40 def get_dialect_visitor(sa_dialect, name):
41 41 """
42 42 Get the visitor implementation for the given dialect.
43 43
44 44 Finds the visitor implementation based on the dialect class and
45 45 returns an instance initialized with the given name.
46 46
47 47 Binds dialect specific preparer to visitor.
48 48 """
49 49
50 50 # map sa dialect to migrate dialect and return visitor
51 51 sa_dialect_name = getattr(sa_dialect, 'name', 'default')
52 52 migrate_dialect_cls = DIALECTS[sa_dialect_name]
53 53 visitor = getattr(migrate_dialect_cls, name)
54 54
55 55 # bind preparer
56 56 visitor.preparer = sa_dialect.preparer(sa_dialect)
57 57
58 58 return visitor
59 59
60 60 def run_single_visitor(engine, visitorcallable, element,
61 61 connection=None, **kwargs):
62 62 """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
63 63 with support for migrate visitors.
64 64 """
65 65 if connection is None:
66 66 conn = engine.contextual_connect(close_with_result=False)
67 67 else:
68 68 conn = connection
69 69 visitor = visitorcallable(engine.dialect, conn)
70 70 try:
71 71 if hasattr(element, '__migrate_visit_name__'):
72 72 fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
73 73 else:
74 74 fn = getattr(visitor, 'visit_' + element.__visit_name__)
75 75 fn(element, **kwargs)
76 76 finally:
77 77 if connection is None:
78 78 conn.close()
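
A minimal sketch of how the mapping above is consumed, assuming SQLAlchemy and the vendored dbmigrate package are importable: asking for the column generator of a SQLite engine returns the SQLite-specific visitor class with the dialect's preparer already bound.

from sqlalchemy import create_engine
from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import \
    get_engine_visitor

engine = create_engine('sqlite://')
# Looks up DIALECTS['sqlite'] and returns its 'columngenerator' attribute.
visitor = get_engine_visitor(engine, 'columngenerator')
print(visitor)
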
@@ -1,669 +1,669 b''
1 1 """
2 2 Schema module providing common schema operations.
3 3 """
4 4 import warnings
5 5
6 6 from UserDict import DictMixin
7 7
8 8 import sqlalchemy
9 9
10 10 from sqlalchemy.schema import ForeignKeyConstraint
11 11 from sqlalchemy.schema import UniqueConstraint
12 12
13 from migrate.exceptions import *
14 from migrate.changeset import SQLA_06
15 from migrate.changeset.databases.visitor import (get_engine_visitor,
13 from rhodecode.lib.dbmigrate.migrate.exceptions import *
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
15 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
16 16 run_single_visitor)
17 17
18 18
19 19 __all__ = [
20 20 'create_column',
21 21 'drop_column',
22 22 'alter_column',
23 23 'rename_table',
24 24 'rename_index',
25 25 'ChangesetTable',
26 26 'ChangesetColumn',
27 27 'ChangesetIndex',
28 28 'ChangesetDefaultClause',
29 29 'ColumnDelta',
30 30 ]
31 31
32 32 DEFAULT_ALTER_METADATA = True
33 33
34 34
35 35 def create_column(column, table=None, *p, **kw):
36 36 """Create a column, given the table.
37 37
38 38 API to :meth:`ChangesetColumn.create`.
39 39 """
40 40 if table is not None:
41 41 return table.create_column(column, *p, **kw)
42 42 return column.create(*p, **kw)
43 43
44 44
45 45 def drop_column(column, table=None, *p, **kw):
46 46 """Drop a column, given the table.
47 47
48 48 API to :meth:`ChangesetColumn.drop`.
49 49 """
50 50 if table is not None:
51 51 return table.drop_column(column, *p, **kw)
52 52 return column.drop(*p, **kw)
53 53
54 54
55 55 def rename_table(table, name, engine=None, **kw):
56 56 """Rename a table.
57 57
58 58 If Table instance is given, engine is not used.
59 59
60 60 API to :meth:`ChangesetTable.rename`.
61 61
62 62 :param table: Table to be renamed.
63 63 :param name: New name for Table.
64 64 :param engine: Engine instance.
65 65 :type table: string or Table instance
66 66 :type name: string
67 67 :type engine: obj
68 68 """
69 69 table = _to_table(table, engine)
70 70 table.rename(name, **kw)
71 71
72 72
73 73 def rename_index(index, name, table=None, engine=None, **kw):
74 74 """Rename an index.
75 75
76 76 If Index instance is given,
77 77 table and engine are not used.
78 78
79 79 API to :meth:`ChangesetIndex.rename`.
80 80
81 81 :param index: Index to be renamed.
82 82 :param name: New name for index.
83 83 :param table: Table to which the Index is referred.
84 84 :param engine: Engine instance.
85 85 :type index: string or Index instance
86 86 :type name: string
87 87 :type table: string or Table instance
88 88 :type engine: obj
89 89 """
90 90 index = _to_index(index, table, engine)
91 91 index.rename(name, **kw)
92 92
93 93
94 94 def alter_column(*p, **k):
95 95 """Alter a column.
96 96
97 97 This is a helper function that creates a :class:`ColumnDelta` and
98 98 runs it.
99 99
100 100 :argument column:
101 101 The name of the column to be altered or a
102 102 :class:`ChangesetColumn` column representing it.
103 103
104 104 :param table:
105 105 A :class:`~sqlalchemy.schema.Table` or table name to
106 106 for the table where the column will be changed.
107 107
108 108 :param engine:
109 109 The :class:`~sqlalchemy.engine.base.Engine` to use for table
110 110 reflection and schema alterations.
111 111
112 112 :param alter_metadata:
113 113 If `True`, which is the default, the
114 114 :class:`~sqlalchemy.schema.Column` will also be modified.
115 115 If `False`, the :class:`~sqlalchemy.schema.Column` will be left
116 116 as it was.
117 117
118 118 :returns: A :class:`ColumnDelta` instance representing the change.
119 119
120 120
121 121 """
122 122
123 123 k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
124 124
125 125 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
126 126 k['table'] = p[0].table
127 127 if 'engine' not in k:
128 128 k['engine'] = k['table'].bind
129 129
130 130 # deprecation
131 131 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
132 132 warnings.warn(
133 133 "Passing a Column object to alter_column is deprecated."
134 134 " Just pass in keyword parameters instead.",
135 135 MigrateDeprecationWarning
136 136 )
137 137 engine = k['engine']
138 138 delta = ColumnDelta(*p, **k)
139 139
140 140 visitorcallable = get_engine_visitor(engine, 'schemachanger')
141 141 engine._run_visitor(visitorcallable, delta)
142 142
143 143 return delta
144 144
145 145
146 146 def _to_table(table, engine=None):
147 147 """Return if instance of Table, else construct new with metadata"""
148 148 if isinstance(table, sqlalchemy.Table):
149 149 return table
150 150
151 151 # Given: table name, maybe an engine
152 152 meta = sqlalchemy.MetaData()
153 153 if engine is not None:
154 154 meta.bind = engine
155 155 return sqlalchemy.Table(table, meta)
156 156
157 157
158 158 def _to_index(index, table=None, engine=None):
159 159 """Return if instance of Index, else construct new with metadata"""
160 160 if isinstance(index, sqlalchemy.Index):
161 161 return index
162 162
163 163 # Given: index name; table name required
164 164 table = _to_table(table, engine)
165 165 ret = sqlalchemy.Index(index)
166 166 ret.table = table
167 167 return ret
168 168
169 169
170 170 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
171 171 """Extracts the differences between two columns/column-parameters
172 172
173 173 May receive parameters arranged in several different ways:
174 174
175 175 * **current_column, new_column, \*p, \*\*kw**
176 176 Additional parameters can be specified to override column
177 177 differences.
178 178
179 179 * **current_column, \*p, \*\*kw**
180 180 Additional parameters alter current_column. Table name is extracted
181 181 from current_column object.
182 182 Name is changed to current_column.name from current_name,
183 183 if current_name is specified.
184 184
185 185 * **current_col_name, \*p, \*\*kw**
186 186 Table kw must be specified.
187 187
188 188 :param table: Table to which the current Column should be bound.\
189 189 If table name is given, reflection will be used.
190 190 :type table: string or Table instance
191 191 :param alter_metadata: If True, it will apply changes to metadata.
192 192 :type alter_metadata: bool
193 193 :param metadata: If `alter_metadata` is true, \
194 194 this metadata is used when reflecting table names
195 195 :type metadata: :class:`MetaData` instance
196 196 :param engine: When reflecting tables, either engine or metadata must \
197 197 be specified to acquire engine object.
198 198 :type engine: :class:`Engine` instance
199 199 :returns: :class:`ColumnDelta` instance that provides access to the altered attributes of \
200 200 `result_column` through a :func:`dict`-like object.
201 201
202 202 * :class:`ColumnDelta`.result_column is altered column with new attributes
203 203
204 204 * :class:`ColumnDelta`.current_name is current name of column in db
205 205
206 206
207 207 """
208 208
209 209 # Column attributes that can be altered
210 210 diff_keys = ('name', 'type', 'primary_key', 'nullable',
211 211 'server_onupdate', 'server_default', 'autoincrement')
212 212 diffs = dict()
213 213 __visit_name__ = 'column'
214 214
215 215 def __init__(self, *p, **kw):
216 216 self.alter_metadata = kw.pop("alter_metadata", False)
217 217 self.meta = kw.pop("metadata", None)
218 218 self.engine = kw.pop("engine", None)
219 219
220 220 # Things are initialized differently depending on how many column
221 221 # parameters are given. Figure out how many and call the appropriate
222 222 # method.
223 223 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
224 224 # At least one column specified
225 225 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
226 226 # Two columns specified
227 227 diffs = self.compare_2_columns(*p, **kw)
228 228 else:
229 229 # Exactly one column specified
230 230 diffs = self.compare_1_column(*p, **kw)
231 231 else:
232 232 # Zero columns specified
233 233 if not len(p) or not isinstance(p[0], basestring):
234 234 raise ValueError("First argument must be column name")
235 235 diffs = self.compare_parameters(*p, **kw)
236 236
237 237 self.apply_diffs(diffs)
238 238
239 239 def __repr__(self):
240 240 return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
241 241 super(ColumnDelta, self).__repr__())
242 242
243 243 def __getitem__(self, key):
244 244 if key not in self.keys():
245 245 raise KeyError("No such diff key, available: %s" % self.diffs)
246 246 return getattr(self.result_column, key)
247 247
248 248 def __setitem__(self, key, value):
249 249 if key not in self.keys():
250 250 raise KeyError("No such diff key, available: %s" % self.diffs)
251 251 setattr(self.result_column, key, value)
252 252
253 253 def __delitem__(self, key):
254 254 raise NotImplementedError
255 255
256 256 def keys(self):
257 257 return self.diffs.keys()
258 258
259 259 def compare_parameters(self, current_name, *p, **k):
260 260 """Compares Column objects with reflection"""
261 261 self.table = k.pop('table')
262 262 self.result_column = self._table.c.get(current_name)
263 263 if len(p):
264 264 k = self._extract_parameters(p, k, self.result_column)
265 265 return k
266 266
267 267 def compare_1_column(self, col, *p, **k):
268 268 """Compares one Column object"""
269 269 self.table = k.pop('table', None)
270 270 if self.table is None:
271 271 self.table = col.table
272 272 self.result_column = col
273 273 if len(p):
274 274 k = self._extract_parameters(p, k, self.result_column)
275 275 return k
276 276
277 277 def compare_2_columns(self, old_col, new_col, *p, **k):
278 278 """Compares two Column objects"""
279 279 self.process_column(new_col)
280 280 self.table = k.pop('table', None)
281 281 # we cannot use bool() on table in SA06
282 282 if self.table is None:
283 283 self.table = old_col.table
284 284 if self.table is None:
285 285 self.table = new_col.table
286 286 self.result_column = old_col
287 287
288 288 # set differences
289 289 # leave out some stuff for later comp
290 290 for key in (set(self.diff_keys) - set(('type',))):
291 291 val = getattr(new_col, key, None)
292 292 if getattr(self.result_column, key, None) != val:
293 293 k.setdefault(key, val)
294 294
295 295 # inspect types
296 296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
297 297 k.setdefault('type', new_col.type)
298 298
299 299 if len(p):
300 300 k = self._extract_parameters(p, k, self.result_column)
301 301 return k
302 302
303 303 def apply_diffs(self, diffs):
304 304 """Populate dict and column object with new values"""
305 305 self.diffs = diffs
306 306 for key in self.diff_keys:
307 307 if key in diffs:
308 308 setattr(self.result_column, key, diffs[key])
309 309
310 310 self.process_column(self.result_column)
311 311
312 312 # create an instance of class type if not yet
313 313 if 'type' in diffs and callable(self.result_column.type):
314 314 self.result_column.type = self.result_column.type()
315 315
316 316 # add column to the table
317 317 if self.table is not None and self.alter_metadata:
318 318 self.result_column.add_to_table(self.table)
319 319
320 320 def are_column_types_eq(self, old_type, new_type):
321 321 """Compares two types to be equal"""
322 322 ret = old_type.__class__ == new_type.__class__
323 323
324 324 # String length is a special case
325 325 if ret and isinstance(new_type, sqlalchemy.types.String):
326 326 ret = (getattr(old_type, 'length', None) == \
327 327 getattr(new_type, 'length', None))
328 328 return ret
329 329
330 330 def _extract_parameters(self, p, k, column):
331 331 """Extracts data from p and modifies diffs"""
332 332 p = list(p)
333 333 while len(p):
334 334 if isinstance(p[0], basestring):
335 335 k.setdefault('name', p.pop(0))
336 336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
337 337 k.setdefault('type', p.pop(0))
338 338 elif callable(p[0]):
339 339 p[0] = p[0]()
340 340 else:
341 341 break
342 342
343 343 if len(p):
344 344 new_col = column.copy_fixed()
345 345 new_col._init_items(*p)
346 346 k = self.compare_2_columns(column, new_col, **k)
347 347 return k
348 348
349 349 def process_column(self, column):
350 350 """Processes default values for column"""
351 351 # XXX: this is a snippet from SA processing of positional parameters
352 352 if not SQLA_06 and column.args:
353 353 toinit = list(column.args)
354 354 else:
355 355 toinit = list()
356 356
357 357 if column.server_default is not None:
358 358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
359 359 toinit.append(column.server_default)
360 360 else:
361 361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
362 362 if column.server_onupdate is not None:
363 363 if isinstance(column.server_onupdate, sqlalchemy.FetchedValue):
364 364 toinit.append(column.server_onupdate)
365 365 else:
366 366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
367 367 for_update=True))
368 368 if toinit:
369 369 column._init_items(*toinit)
370 370
371 371 if not SQLA_06:
372 372 column.args = []
373 373
374 374 def _get_table(self):
375 375 return getattr(self, '_table', None)
376 376
377 377 def _set_table(self, table):
378 378 if isinstance(table, basestring):
379 379 if self.alter_metadata:
380 380 if not self.meta:
381 381 raise ValueError("metadata must be specified for table"
382 382 " reflection when using alter_metadata")
383 383 meta = self.meta
384 384 if self.engine:
385 385 meta.bind = self.engine
386 386 else:
387 387 if not self.engine and not self.meta:
388 388 raise ValueError("engine or metadata must be specified"
389 389 " to reflect tables")
390 390 if not self.engine:
391 391 self.engine = self.meta.bind
392 392 meta = sqlalchemy.MetaData(bind=self.engine)
393 393 self._table = sqlalchemy.Table(table, meta, autoload=True)
394 394 elif isinstance(table, sqlalchemy.Table):
395 395 self._table = table
396 396 if not self.alter_metadata:
397 397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
398 398
399 399 def _get_result_column(self):
400 400 return getattr(self, '_result_column', None)
401 401
402 402 def _set_result_column(self, column):
403 403 """Set Column to Table based on alter_metadata evaluation."""
404 404 self.process_column(column)
405 405 if not hasattr(self, 'current_name'):
406 406 self.current_name = column.name
407 407 if self.alter_metadata:
408 408 self._result_column = column
409 409 else:
410 410 self._result_column = column.copy_fixed()
411 411
412 412 table = property(_get_table, _set_table)
413 413 result_column = property(_get_result_column, _set_result_column)
414 414
415 415
416 416 class ChangesetTable(object):
417 417 """Changeset extensions to SQLAlchemy tables."""
418 418
419 419 def create_column(self, column, *p, **kw):
420 420 """Creates a column.
421 421
422 422 The column parameter may be a column definition or the name of
423 423 a column in this table.
424 424
425 425 API to :meth:`ChangesetColumn.create`
426 426
427 427 :param column: Column to be created
428 428 :type column: Column instance or string
429 429 """
430 430 if not isinstance(column, sqlalchemy.Column):
431 431 # It's a column name
432 432 column = getattr(self.c, str(column))
433 433 column.create(table=self, *p, **kw)
434 434
435 435 def drop_column(self, column, *p, **kw):
436 436 """Drop a column, given its name or definition.
437 437
438 438 API to :meth:`ChangesetColumn.drop`
439 439
440 440 :param column: Column to be dropped
441 441 :type column: Column instance or string
442 442 """
443 443 if not isinstance(column, sqlalchemy.Column):
444 444 # It's a column name
445 445 try:
446 446 column = getattr(self.c, str(column))
447 447 except AttributeError:
448 448 # That column isn't part of the table. We don't need
449 449 # its entire definition to drop the column, just its
450 450 # name, so create a dummy column with the same name.
451 451 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
452 452 column.drop(table=self, *p, **kw)
453 453
454 454 def rename(self, name, connection=None, **kwargs):
455 455 """Rename this table.
456 456
457 457 :param name: New name of the table.
458 458 :type name: string
459 459 :param alter_metadata: If True, table will be removed from metadata
460 460 :type alter_metadata: bool
461 461 :param connection: reuse connection instead of creating a new one.
462 462 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
463 463 """
464 464 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
465 465 engine = self.bind
466 466 self.new_name = name
467 467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469 469
470 470 # Fix metadata registration
471 471 if self.alter_metadata:
472 472 self.name = name
473 473 self.deregister()
474 474 self._set_parent(self.metadata)
475 475
476 476 def _meta_key(self):
477 477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478 478
479 479 def deregister(self):
480 480 """Remove this table from its metadata"""
481 481 key = self._meta_key()
482 482 meta = self.metadata
483 483 if key in meta.tables:
484 484 del meta.tables[key]
485 485
486 486
487 487 class ChangesetColumn(object):
488 488 """Changeset extensions to SQLAlchemy columns."""
489 489
490 490 def alter(self, *p, **k):
491 491 """Makes a call to :func:`alter_column` for the column this
492 492 method is called on.
493 493 """
494 494 if 'table' not in k:
495 495 k['table'] = self.table
496 496 if 'engine' not in k:
497 497 k['engine'] = k['table'].bind
498 498 return alter_column(self, *p, **k)
499 499
500 500 def create(self, table=None, index_name=None, unique_name=None,
501 501 primary_key_name=None, populate_default=True, connection=None, **kwargs):
502 502 """Create this column in the database.
503 503
504 504 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
505 505 for most databases.
506 506
507 507 :param table: Table instance to create on.
508 508 :param index_name: Creates :class:`ChangesetIndex` on this column.
509 509 :param unique_name: Creates :class:\
510 510 `~migrate.changeset.constraint.UniqueConstraint` on this column.
511 511 :param primary_key_name: Creates :class:\
512 512 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
513 513 :param alter_metadata: If True, column will be added to table object.
514 514 :param populate_default: If True, created column will be \
515 515 populated with defaults
516 516 :param connection: reuse connection instead of creating a new one.
517 517 :type table: Table instance
518 518 :type index_name: string
519 519 :type unique_name: string
520 520 :type primary_key_name: string
521 521 :type alter_metadata: bool
522 522 :type populate_default: bool
523 523 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
524 524
525 525 :returns: self
526 526 """
527 527 self.populate_default = populate_default
528 528 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
529 529 self.index_name = index_name
530 530 self.unique_name = unique_name
531 531 self.primary_key_name = primary_key_name
532 532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 533 self._check_sanity_constraints(cons)
534 534
535 535 if self.alter_metadata:
536 536 self.add_to_table(table)
537 537 engine = self.table.bind
538 538 visitorcallable = get_engine_visitor(engine, 'columngenerator')
539 539 engine._run_visitor(visitorcallable, self, connection, **kwargs)
540 540
541 541 # TODO: reuse existing connection
542 542 if self.populate_default and self.default is not None:
543 543 stmt = table.update().values({self: engine._execute_default(self.default)})
544 544 engine.execute(stmt)
545 545
546 546 return self
547 547
548 548 def drop(self, table=None, connection=None, **kwargs):
549 549 """Drop this column from the database, leaving its table intact.
550 550
551 551 ``ALTER TABLE DROP COLUMN``, for most databases.
552 552
553 553 :param alter_metadata: If True, column will be removed from table object.
554 554 :type alter_metadata: bool
555 555 :param connection: reuse connection instead of creating a new one.
556 556 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
557 557 """
558 558 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
559 559 if table is not None:
560 560 self.table = table
561 561 engine = self.table.bind
562 562 if self.alter_metadata:
563 563 self.remove_from_table(self.table, unset_table=False)
564 564 visitorcallable = get_engine_visitor(engine, 'columndropper')
565 565 engine._run_visitor(visitorcallable, self, connection, **kwargs)
566 566 if self.alter_metadata:
567 567 self.table = None
568 568 return self
569 569
570 570 def add_to_table(self, table):
571 571 if table is not None and self.table is None:
572 572 self._set_parent(table)
573 573
574 574 def _col_name_in_constraint(self, cons, name):
575 575 return False
576 576
577 577 def remove_from_table(self, table, unset_table=True):
578 578 # TODO: remove primary keys, constraints, etc
579 579 if unset_table:
580 580 self.table = None
581 581
582 582 to_drop = set()
583 583 for index in table.indexes:
584 584 columns = []
585 585 for col in index.columns:
586 586 if col.name != self.name:
587 587 columns.append(col)
588 588 if columns:
589 589 index.columns = columns
590 590 else:
591 591 to_drop.add(index)
592 592 table.indexes = table.indexes - to_drop
593 593
594 594 to_drop = set()
595 595 for cons in table.constraints:
596 596 # TODO: deal with other types of constraint
597 597 if isinstance(cons, (ForeignKeyConstraint,
598 598 UniqueConstraint)):
599 599 for col_name in cons.columns:
600 600 if not isinstance(col_name, basestring):
601 601 col_name = col_name.name
602 602 if self.name == col_name:
603 603 to_drop.add(cons)
604 604 table.constraints = table.constraints - to_drop
605 605
606 606 if table.c.contains_column(self):
607 607 table.c.remove(self)
608 608
609 609 # TODO: this is fixed in 0.6
610 610 def copy_fixed(self, **kw):
611 611 """Create a copy of this ``Column``, with all attributes."""
612 612 return sqlalchemy.Column(self.name, self.type, self.default,
613 613 key=self.key,
614 614 primary_key=self.primary_key,
615 615 nullable=self.nullable,
616 616 quote=self.quote,
617 617 index=self.index,
618 618 unique=self.unique,
619 619 onupdate=self.onupdate,
620 620 autoincrement=self.autoincrement,
621 621 server_default=self.server_default,
622 622 server_onupdate=self.server_onupdate,
623 623 *[c.copy(**kw) for c in self.constraints])
624 624
625 625 def _check_sanity_constraints(self, name):
626 626 """Check if constraints names are correct"""
627 627 obj = getattr(self, name)
628 628 if (getattr(self, name[:-5]) and not obj):
629 629 raise InvalidConstraintError("Column.create() accepts index_name,"
630 630 " primary_key_name and unique_name to generate constraints")
631 631 if not isinstance(obj, basestring) and obj is not None:
632 632 raise InvalidConstraintError(
633 633 "%s argument for column must be constraint name" % name)
634 634
635 635
636 636 class ChangesetIndex(object):
637 637 """Changeset extensions to SQLAlchemy Indexes."""
638 638
639 639 __visit_name__ = 'index'
640 640
641 641 def rename(self, name, connection=None, **kwargs):
642 642 """Change the name of an index.
643 643
644 644 :param name: New name of the Index.
645 645 :type name: string
646 646 :param alter_metadata: If True, Index object will be altered.
647 647 :type alter_metadata: bool
648 648 :param connection: reuse connection instead of creating a new one.
649 649 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
650 650 """
651 651 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
652 652 engine = self.table.bind
653 653 self.new_name = name
654 654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 656 if self.alter_metadata:
657 657 self.name = name
658 658
659 659
660 660 class ChangesetDefaultClause(object):
661 661 """Implements comparison between :class:`DefaultClause` instances"""
662 662
663 663 def __eq__(self, other):
664 664 if isinstance(other, self.__class__):
665 665 if self.arg == other.arg:
666 666 return True
667 667
668 668 def __ne__(self, other):
669 669 return not self.__eq__(other)
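
A short sketch of the module-level helpers defined above, assuming a SQLite engine, the vendored imports, and that importing the changeset package applies its usual SQLAlchemy monkeypatching, which is what makes Table.create_column() and Column.drop() available. Names below are illustrative; alter_column() works analogously by passing keyword parameters such as name or type.

from sqlalchemy import Table, Column, Integer, String, MetaData, create_engine
from rhodecode.lib.dbmigrate.migrate.changeset.schema import (
    create_column, drop_column)

engine = create_engine('sqlite://')
meta = MetaData(bind=engine)
users = Table('users', meta, Column('id', Integer, primary_key=True))
users.create()

# roughly: ALTER TABLE users ADD COLUMN fullname VARCHAR(255)
fullname = Column('fullname', String(255))
create_column(fullname, users)

# drop it again; on SQLite the column dropper rebuilds the table behind the scenes
drop_column(fullname, users)
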
@@ -1,383 +1,383 b''
1 1 """
2 2 This module provides an external API to the versioning system.
3 3
4 4 .. versionchanged:: 0.6.0
5 5 :func:`migrate.versioning.api.test` and schema diff functions
6 6 changed order of positional arguments so all accept `url` and `repository`
7 7 as first arguments.
8 8
9 9 .. versionchanged:: 0.5.4
10 10 ``--preview_sql`` displays source file when using SQL scripts.
11 11 If Python script is used, it runs the action with mocked engine and
12 12 returns captured SQL statements.
13 13
14 14 .. versionchanged:: 0.5.4
15 15 Deprecated ``--echo`` parameter in favour of new
16 16 :func:`migrate.versioning.util.construct_engine` behavior.
17 17 """
18 18
19 19 # Dear migrate developers,
20 20 #
21 21 # please do not comment this module using sphinx syntax because its
22 22 # docstrings are presented as user help and most users cannot
23 23 # interpret sphinx annotated ReStructuredText.
24 24 #
25 25 # Thanks,
26 26 # Jan Dittberner
27 27
28 28 import sys
29 29 import inspect
30 30 import logging
31 31
32 from migrate import exceptions
33 from migrate.versioning import (repository, schema, version,
32 from rhodecode.lib.dbmigrate.migrate import exceptions
33 from rhodecode.lib.dbmigrate.migrate.versioning import (repository, schema, version,
34 34 script as script_) # command name conflict
35 from migrate.versioning.util import catch_known_errors, with_engine
35 from rhodecode.lib.dbmigrate.migrate.versioning.util import catch_known_errors, with_engine
36 36
37 37
38 38 log = logging.getLogger(__name__)
39 39 command_desc = {
40 40 'help': 'displays help on a given command',
41 41 'create': 'create an empty repository at the specified path',
42 42 'script': 'create an empty change Python script',
43 43 'script_sql': 'create empty change SQL scripts for given database',
44 44 'version': 'display the latest version available in a repository',
45 45 'db_version': 'show the current version of the repository under version control',
46 46 'source': 'display the Python code for a particular version in this repository',
47 47 'version_control': 'mark a database as under this repository\'s version control',
48 48 'upgrade': 'upgrade a database to a later version',
49 49 'downgrade': 'downgrade a database to an earlier version',
50 50 'drop_version_control': 'removes version control from a database',
51 51 'manage': 'creates a Python script that runs Migrate with a set of default values',
52 52 'test': 'performs the upgrade and downgrade command on the given database',
53 53 'compare_model_to_db': 'compare MetaData against the current database state',
54 54 'create_model': 'dump the current database as a Python model to stdout',
55 55 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
56 56 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
57 57 }
58 58 __all__ = command_desc.keys()
59 59
60 60 Repository = repository.Repository
61 61 ControlledSchema = schema.ControlledSchema
62 62 VerNum = version.VerNum
63 63 PythonScript = script_.PythonScript
64 64 SqlScript = script_.SqlScript
65 65
66 66
67 67 # deprecated
68 68 def help(cmd=None, **opts):
69 69 """%prog help COMMAND
70 70
71 71 Displays help on a given command.
72 72 """
73 73 if cmd is None:
74 74 raise exceptions.UsageError(None)
75 75 try:
76 76 func = globals()[cmd]
77 77 except:
78 78 raise exceptions.UsageError(
79 79 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
80 80 ret = func.__doc__
81 81 if sys.argv[0]:
82 82 ret = ret.replace('%prog', sys.argv[0])
83 83 return ret
84 84
85 85 @catch_known_errors
86 86 def create(repository, name, **opts):
87 87 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
88 88
89 89 Create an empty repository at the specified path.
90 90
91 91 You can specify the version_table to be used; by default, it is
92 92 'migrate_version'. This table is created in all version-controlled
93 93 databases.
94 94 """
95 95 repo_path = Repository.create(repository, name, **opts)
96 96
97 97
98 98 @catch_known_errors
99 99 def script(description, repository, **opts):
100 100 """%prog script DESCRIPTION REPOSITORY_PATH
101 101
102 102 Create an empty change script using the next unused version number
103 103 appended with the given description.
104 104
105 105 For instance, manage.py script "Add initial tables" creates:
106 106 repository/versions/001_Add_initial_tables.py
107 107 """
108 108 repo = Repository(repository)
109 109 repo.create_script(description, **opts)
110 110
111 111
112 112 @catch_known_errors
113 113 def script_sql(database, repository, **opts):
114 114 """%prog script_sql DATABASE REPOSITORY_PATH
115 115
116 116 Create empty change SQL scripts for given DATABASE, where DATABASE
117 117 is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
118 118 or generic ('default').
119 119
120 120 For instance, manage.py script_sql postgres creates:
121 121 repository/versions/001_postgres_upgrade.sql and
122 122 repository/versions/001_postgres_downgrade.sql
123 123 """
124 124 repo = Repository(repository)
125 125 repo.create_script_sql(database, **opts)
126 126
127 127
128 128 def version(repository, **opts):
129 129 """%prog version REPOSITORY_PATH
130 130
131 131 Display the latest version available in a repository.
132 132 """
133 133 repo = Repository(repository)
134 134 return repo.latest
135 135
136 136
137 137 @with_engine
138 138 def db_version(url, repository, **opts):
139 139 """%prog db_version URL REPOSITORY_PATH
140 140
141 141 Show the current version of the repository with the given
142 142 connection string, under version control of the specified
143 143 repository.
144 144
145 145 The url should be any valid SQLAlchemy connection string.
146 146 """
147 147 engine = opts.pop('engine')
148 148 schema = ControlledSchema(engine, repository)
149 149 return schema.version
150 150
151 151
152 152 def source(version, dest=None, repository=None, **opts):
153 153 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
154 154
155 155 Display the Python code for a particular version in this
156 156 repository. Save it to the file at DESTINATION or, if omitted,
157 157 send to stdout.
158 158 """
159 159 if repository is None:
160 160 raise exceptions.UsageError("A repository must be specified")
161 161 repo = Repository(repository)
162 162 ret = repo.version(version).script().source()
163 163 if dest is not None:
164 164 dest = open(dest, 'w')
165 165 dest.write(ret)
166 166 dest.close()
167 167 ret = None
168 168 return ret
169 169
170 170
171 171 def upgrade(url, repository, version=None, **opts):
172 172 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
173 173
174 174 Upgrade a database to a later version.
175 175
176 176 This runs the upgrade() function defined in your change scripts.
177 177
178 178 By default, the database is updated to the latest available
179 179 version. You may specify a version instead, if you wish.
180 180
181 181 You may preview the Python or SQL code to be executed, rather than
182 182 actually executing it, using the appropriate 'preview' option.
183 183 """
184 184 err = "Cannot upgrade a database of version %s to version %s. "\
185 185 "Try 'downgrade' instead."
186 186 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
187 187
188 188
189 189 def downgrade(url, repository, version, **opts):
190 190 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
191 191
192 192 Downgrade a database to an earlier version.
193 193
194 194 This is the reverse of upgrade; this runs the downgrade() function
195 195 defined in your change scripts.
196 196
197 197 You may preview the Python or SQL code to be executed, rather than
198 198 actually executing it, using the appropriate 'preview' option.
199 199 """
200 200 err = "Cannot downgrade a database of version %s to version %s. "\
201 201 "Try 'upgrade' instead."
202 202 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
203 203
204 204 @with_engine
205 205 def test(url, repository, **opts):
206 206 """%prog test URL REPOSITORY_PATH [VERSION]
207 207
208 208 Performs the upgrade and downgrade option on the given
209 209 database. This is not a real test and may leave the database in a
210 210 bad state. You should therefore run the test on a copy of
211 211 your database.
212 212 """
213 213 engine = opts.pop('engine')
214 214 repos = Repository(repository)
215 215 script = repos.version(None).script()
216 216
217 217 # Upgrade
218 218 log.info("Upgrading...")
219 219 script.run(engine, 1)
220 220 log.info("done")
221 221
222 222 log.info("Downgrading...")
223 223 script.run(engine, -1)
224 224 log.info("done")
225 225 log.info("Success")
226 226
227 227
228 228 @with_engine
229 229 def version_control(url, repository, version=None, **opts):
230 230 """%prog version_control URL REPOSITORY_PATH [VERSION]
231 231
232 232 Mark a database as under this repository's version control.
233 233
234 234 Once a database is under version control, schema changes should
235 235 only be done via change scripts in this repository.
236 236
237 237 This creates the table version_table in the database.
238 238
239 239 The url should be any valid SQLAlchemy connection string.
240 240
241 241 By default, the database begins at version 0 and is assumed to be
242 242 empty. If the database is not empty, you may specify a version at
243 243 which to begin instead. No attempt is made to verify this
244 244 version's correctness - the database schema is expected to be
245 245 identical to what it would be if the database were created from
246 246 scratch.
247 247 """
248 248 engine = opts.pop('engine')
249 249 ControlledSchema.create(engine, repository, version)
250 250
251 251
252 252 @with_engine
253 253 def drop_version_control(url, repository, **opts):
254 254 """%prog drop_version_control URL REPOSITORY_PATH
255 255
256 256 Removes version control from a database.
257 257 """
258 258 engine = opts.pop('engine')
259 259 schema = ControlledSchema(engine, repository)
260 260 schema.drop()
261 261
262 262
263 263 def manage(file, **opts):
264 264 """%prog manage FILENAME [VARIABLES...]
265 265
266 266 Creates a script that runs Migrate with a set of default values.
267 267
268 268 For example::
269 269
270 270 %prog manage manage.py --repository=/path/to/repository \
271 271 --url=sqlite:///project.db
272 272
273 273 would create the script manage.py. The following two commands
274 274 would then have exactly the same results::
275 275
276 276 python manage.py version
277 277 %prog version --repository=/path/to/repository
278 278 """
279 279 Repository.create_manage_file(file, **opts)
280 280
281 281
282 282 @with_engine
283 283 def compare_model_to_db(url, repository, model, **opts):
284 284 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
285 285
286 286 Compare the current model (assumed to be a module level variable
287 287 of type sqlalchemy.MetaData) against the current database.
288 288
289 289 NOTE: This is EXPERIMENTAL.
290 290 """ # TODO: get rid of EXPERIMENTAL label
291 291 engine = opts.pop('engine')
292 292 return ControlledSchema.compare_model_to_db(engine, model, repository)
293 293
294 294
295 295 @with_engine
296 296 def create_model(url, repository, **opts):
297 297 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
298 298
299 299 Dump the current database as a Python model to stdout.
300 300
301 301 NOTE: This is EXPERIMENTAL.
302 302 """ # TODO: get rid of EXPERIMENTAL label
303 303 engine = opts.pop('engine')
304 304 declarative = opts.get('declarative', False)
305 305 return ControlledSchema.create_model(engine, repository, declarative)
306 306
307 307
308 308 @catch_known_errors
309 309 @with_engine
310 310 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
311 311 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
312 312
313 313 Create a script changing the old Python model to the new (current)
314 314 Python model, sending to stdout.
315 315
316 316 NOTE: This is EXPERIMENTAL.
317 317 """ # TODO: get rid of EXPERIMENTAL label
318 318 engine = opts.pop('engine')
319 319 return PythonScript.make_update_script_for_model(
320 320 engine, oldmodel, model, repository, **opts)
321 321
322 322
323 323 @with_engine
324 324 def update_db_from_model(url, repository, model, **opts):
325 325 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
326 326
327 327 Modify the database to match the structure of the current Python
328 328 model. This also sets the db_version number to the latest in the
329 329 repository.
330 330
331 331 NOTE: This is EXPERIMENTAL.
332 332 """ # TODO: get rid of EXPERIMENTAL label
333 333 engine = opts.pop('engine')
334 334 schema = ControlledSchema(engine, repository)
335 335 schema.update_db_from_model(model)
336 336
337 337 @with_engine
338 338 def _migrate(url, repository, version, upgrade, err, **opts):
339 339 engine = opts.pop('engine')
340 340 url = str(engine.url)
341 341 schema = ControlledSchema(engine, repository)
342 342 version = _migrate_version(schema, version, upgrade, err)
343 343
344 344 changeset = schema.changeset(version)
345 345 for ver, change in changeset:
346 346 nextver = ver + changeset.step
347 347 log.info('%s -> %s... ', ver, nextver)
348 348
349 349 if opts.get('preview_sql'):
350 350 if isinstance(change, PythonScript):
351 351 log.info(change.preview_sql(url, changeset.step, **opts))
352 352 elif isinstance(change, SqlScript):
353 353 log.info(change.source())
354 354
355 355 elif opts.get('preview_py'):
356 356 if not isinstance(change, PythonScript):
357 357 raise exceptions.UsageError("Python source can only be displayed"
358 358 " for Python migration files")
359 359 source_ver = max(ver, nextver)
360 360 module = schema.repository.version(source_ver).script().module
361 361 funcname = upgrade and "upgrade" or "downgrade"
362 362 func = getattr(module, funcname)
363 363 log.info(inspect.getsource(func))
364 364 else:
365 365 schema.runchange(ver, change, changeset.step)
366 366 log.info('done')
367 367
368 368
369 369 def _migrate_version(schema, version, upgrade, err):
370 370 if version is None:
371 371 return version
372 372 # Version is specified: ensure we're upgrading in the right direction
373 373 # (current version < target version for upgrading; reverse for down)
374 374 version = VerNum(version)
375 375 cur = schema.version
376 376 if upgrade is not None:
377 377 if upgrade:
378 378 direction = cur <= version
379 379 else:
380 380 direction = cur >= version
381 381 if not direction:
382 382 raise exceptions.KnownError(err % (cur, version))
383 383 return version
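
A sketch of driving this versioning API programmatically rather than through a manage.py wrapper; the repository path and database URL below are illustrative, not taken from the RhodeCode configuration.

from rhodecode.lib.dbmigrate.migrate.versioning import api

REPO = 'rhodecode/lib/dbmigrate'   # directory containing migrate.cfg and versions/
URL = 'sqlite:///rhodecode.db'     # any valid SQLAlchemy connection string

api.version_control(URL, REPO)     # create the version table, starting at version 0
print(api.db_version(URL, REPO))   # current version recorded in the database
print(api.version(REPO))           # latest version available in the repository
api.upgrade(URL, REPO)             # run change scripts up to the latest version
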
@@ -1,27 +1,27 b''
1 1 """
2 2 Configuration parser module.
3 3 """
4 4
5 5 from ConfigParser import ConfigParser
6 6
7 from migrate.versioning.config import *
8 from migrate.versioning import pathed
7 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
8 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
9 9
10 10
11 11 class Parser(ConfigParser):
12 12 """A project configuration file."""
13 13
14 14 def to_dict(self, sections=None):
15 15 """It's easier to access config values like dictionaries"""
16 16 return self._sections
17 17
18 18
19 19 class Config(pathed.Pathed, Parser):
20 20 """Configuration class."""
21 21
22 22 def __init__(self, path, *p, **k):
23 23 """Confirm the config file exists; read it."""
24 24 self.require_found(path)
25 25 pathed.Pathed.__init__(self, path)
26 26 Parser.__init__(self, *p, **k)
27 27 self.read(path)
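
A small sketch of reading a repository's migrate.cfg through this class (the path is illustrative); the same 'db_settings' options are what Repository.version_table and Repository.id read further below.

from rhodecode.lib.dbmigrate.migrate.versioning.cfgparse import Config

cfg = Config('rhodecode/lib/dbmigrate/migrate.cfg')
print(cfg.get('db_settings', 'version_table'))
print(cfg.get('db_settings', 'repository_id'))
print(cfg.to_dict().keys())
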
@@ -1,254 +1,253 b''
1 1 """
2 2 Code to generate a Python model from a database or differences
3 3 between a model and database.
4 4
5 5 Some of this is borrowed heavily from the AutoCode project at:
6 6 http://code.google.com/p/sqlautocode/
7 7 """
8 8
9 9 import sys
10 10 import logging
11 11
12 12 import sqlalchemy
13 13
14 import migrate
15 import migrate.changeset
16
14 from rhodecode.lib.dbmigrate import migrate
15 from rhodecode.lib.dbmigrate.migrate import changeset
17 16
18 17 log = logging.getLogger(__name__)
19 18 HEADER = """
20 19 ## File autogenerated by genmodel.py
21 20
22 21 from sqlalchemy import *
23 22 meta = MetaData()
24 23 """
25 24
26 25 DECLARATIVE_HEADER = """
27 26 ## File autogenerated by genmodel.py
28 27
29 28 from sqlalchemy import *
30 29 from sqlalchemy.ext import declarative
31 30
32 31 Base = declarative.declarative_base()
33 32 """
34 33
35 34
36 35 class ModelGenerator(object):
37 36
38 37 def __init__(self, diff, engine, declarative=False):
39 38 self.diff = diff
40 39 self.engine = engine
41 40 self.declarative = declarative
42 41
43 42 def column_repr(self, col):
44 43 kwarg = []
45 44 if col.key != col.name:
46 45 kwarg.append('key')
47 46 if col.primary_key:
48 47 col.primary_key = True # otherwise it dumps it as 1
49 48 kwarg.append('primary_key')
50 49 if not col.nullable:
51 50 kwarg.append('nullable')
52 51 if col.onupdate:
53 52 kwarg.append('onupdate')
54 53 if col.default:
55 54 if col.primary_key:
56 55 # I found that PostgreSQL automatically creates a
57 56 # default value for the sequence, but let's not show
58 57 # that.
59 58 pass
60 59 else:
61 60 kwarg.append('default')
62 61 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
63 62
64 63 # crs: not sure if this is good idea, but it gets rid of extra
65 64 # u''
66 65 name = col.name.encode('utf8')
67 66
68 67 type_ = col.type
69 68 for cls in col.type.__class__.__mro__:
70 69 if cls.__module__ == 'sqlalchemy.types' and \
71 70 not cls.__name__.isupper():
72 71 if cls is not type_.__class__:
73 72 type_ = cls()
74 73 break
75 74
76 75 data = {
77 76 'name': name,
78 77 'type': type_,
79 78 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
80 79 'args': ks and ks or ''}
81 80
82 81 if data['constraints']:
83 82 if data['args']:
84 83 data['args'] = ',' + data['args']
85 84
86 85 if data['constraints'] or data['args']:
87 86 data['maybeComma'] = ','
88 87 else:
89 88 data['maybeComma'] = ''
90 89
91 90 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
92 91 commonStuff = commonStuff.strip()
93 92 data['commonStuff'] = commonStuff
94 93 if self.declarative:
95 94 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
96 95 else:
97 96 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
98 97
99 98 def getTableDefn(self, table):
100 99 out = []
101 100 tableName = table.name
102 101 if self.declarative:
103 102 out.append("class %(table)s(Base):" % {'table': tableName})
104 103 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
105 104 for col in table.columns:
106 105 out.append(" %s" % self.column_repr(col))
107 106 else:
108 107 out.append("%(table)s = Table('%(table)s', meta," % \
109 108 {'table': tableName})
110 109 for col in table.columns:
111 110 out.append(" %s," % self.column_repr(col))
112 111 out.append(")")
113 112 return out
114 113
115 def _get_tables(self,missingA=False,missingB=False,modified=False):
114 def _get_tables(self, missingA=False, missingB=False, modified=False):
116 115 to_process = []
117 for bool_,names,metadata in (
118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
120 (modified,self.diff.tables_different,self.diff.metadataA),
116 for bool_, names, metadata in (
117 (missingA, self.diff.tables_missing_from_A, self.diff.metadataB),
118 (missingB, self.diff.tables_missing_from_B, self.diff.metadataA),
119 (modified, self.diff.tables_different, self.diff.metadataA),
121 120 ):
122 121 if bool_:
123 122 for name in names:
124 123 yield metadata.tables.get(name)
125
124
126 125 def toPython(self):
127 126 """Assume database is current and model is empty."""
128 127 out = []
129 128 if self.declarative:
130 129 out.append(DECLARATIVE_HEADER)
131 130 else:
132 131 out.append(HEADER)
133 132 out.append("")
134 133 for table in self._get_tables(missingA=True):
135 134 out.extend(self.getTableDefn(table))
136 135 out.append("")
137 136 return '\n'.join(out)
138 137
139 138 def toUpgradeDowngradePython(self, indent=' '):
140 139 ''' Assume model is most current and database is out-of-date. '''
141 decls = ['from migrate.changeset import schema',
140 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
142 141 'meta = MetaData()']
143 142 for table in self._get_tables(
144 missingA=True,missingB=True,modified=True
143 missingA=True, missingB=True, modified=True
145 144 ):
146 145 decls.extend(self.getTableDefn(table))
147 146
148 147 upgradeCommands, downgradeCommands = [], []
149 148 for tableName in self.diff.tables_missing_from_A:
150 149 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
151 150 downgradeCommands.append("%(table)s.create()" % \
152 151 {'table': tableName})
153 152 for tableName in self.diff.tables_missing_from_B:
154 153 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
155 154 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
156 155
157 156 for tableName in self.diff.tables_different:
158 157 dbTable = self.diff.metadataB.tables[tableName]
159 158 missingInDatabase, missingInModel, diffDecl = \
160 159 self.diff.colDiffs[tableName]
161 160 for col in missingInDatabase:
162 161 upgradeCommands.append('%s.columns[%r].create()' % (
163 162 tableName, col.name))
164 163 downgradeCommands.append('%s.columns[%r].drop()' % (
165 164 tableName, col.name))
166 165 for col in missingInModel:
167 166 upgradeCommands.append('%s.columns[%r].drop()' % (
168 167 tableName, col.name))
169 168 downgradeCommands.append('%s.columns[%r].create()' % (
170 169 tableName, col.name))
171 170 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
172 171 upgradeCommands.append(
173 172 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
174 173 tableName, modelCol.name, databaseCol.name))
175 174 downgradeCommands.append(
176 175 'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
177 176 tableName, modelCol.name, databaseCol.name))
178 177 pre_command = ' meta.bind = migrate_engine'
179 178
180 179 return (
181 180 '\n'.join(decls),
182 181 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
183 182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
184 183
185 def _db_can_handle_this_change(self,td):
184 def _db_can_handle_this_change(self, td):
186 185 if (td.columns_missing_from_B
187 186 and not td.columns_missing_from_A
188 187 and not td.columns_different):
189 188 # Even sqlite can handle this.
190 189 return True
191 190 else:
192 191 return not self.engine.url.drivername.startswith('sqlite')
193 192
194 193 def applyModel(self):
195 194 """Apply model to current database."""
196 195
197 196 meta = sqlalchemy.MetaData(self.engine)
198 197
199 198 for table in self._get_tables(missingA=True):
200 199 table = table.tometadata(meta)
201 200 table.drop()
202 201 for table in self._get_tables(missingB=True):
203 202 table = table.tometadata(meta)
204 203 table.create()
205 204 for modelTable in self._get_tables(modified=True):
206 205 tableName = modelTable.name
207 206 modelTable = modelTable.tometadata(meta)
208 207 dbTable = self.diff.metadataB.tables[tableName]
209 208
210 209 td = self.diff.tables_different[tableName]
211
210
212 211 if self._db_can_handle_this_change(td):
213
212
214 213 for col in td.columns_missing_from_B:
215 214 modelTable.columns[col].create()
216 215 for col in td.columns_missing_from_A:
217 216 dbTable.columns[col].drop()
218 217 # XXX handle column changes here.
219 218 else:
220 219 # Sqlite doesn't support drop column, so you have to
221 220 # do more: create temp table, copy data to it, drop
222 221 # old table, create new table, copy data back.
223 222 #
224 223 # I wonder if this is guaranteed to be unique?
225 224 tempName = '_temp_%s' % modelTable.name
226 225
227 226 def getCopyStatement():
228 227 preparer = self.engine.dialect.preparer
229 228 commonCols = []
230 229 for modelCol in modelTable.columns:
231 230 if modelCol.name in dbTable.columns:
232 231 commonCols.append(modelCol.name)
233 232 commonColsStr = ', '.join(commonCols)
234 233 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
235 234 (tableName, commonColsStr, commonColsStr, tempName)
236 235
237 236 # Move the data in one transaction, so that we don't
238 237 # leave the database in a nasty state.
239 238 connection = self.engine.connect()
240 239 trans = connection.begin()
241 240 try:
242 241 connection.execute(
243 242 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
244 243 (tempName, modelTable.name))
245 244 # make sure the drop takes place inside our
246 245 # transaction with the bind parameter
247 246 modelTable.drop(bind=connection)
248 247 modelTable.create(bind=connection)
249 248 connection.execute(getCopyStatement())
250 249 connection.execute('DROP TABLE %s' % tempName)
251 250 trans.commit()
252 251 except:
253 252 trans.rollback()
254 253 raise
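
ModelGenerator is normally reached through the api module shown earlier rather than instantiated directly; a short sketch of the relevant entry points, both marked EXPERIMENTAL upstream (URL and repository path are illustrative):

from rhodecode.lib.dbmigrate.migrate.versioning import api

URL = 'sqlite:///rhodecode.db'
REPO = 'rhodecode/lib/dbmigrate'

# Dump the live database as a Python model, i.e. the HEADER/column_repr
# output generated by toPython() above.
print(api.create_model(URL, REPO))

# api.compare_model_to_db(URL, REPO, model) takes a dotted path to a
# module-level MetaData object; the resulting diff is what ModelGenerator
# consumes when emitting upgrade()/downgrade() bodies.
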
@@ -1,75 +1,75 b''
1 1 """
2 2 A path/directory class.
3 3 """
4 4
5 5 import os
6 6 import shutil
7 7 import logging
8 8
9 from migrate import exceptions
10 from migrate.versioning.config import *
11 from migrate.versioning.util import KeyedInstance
9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
11 from rhodecode.lib.dbmigrate.migrate.versioning.util import KeyedInstance
12 12
13 13
14 14 log = logging.getLogger(__name__)
15 15
16 16 class Pathed(KeyedInstance):
17 17 """
18 18 A class associated with a path/directory tree.
19 19
20 20 Only one instance of this class may exist for a particular file;
21 21 __new__ will return an existing instance if possible
22 22 """
23 23 parent = None
24 24
25 25 @classmethod
26 26 def _key(cls, path):
27 27 return str(path)
28 28
29 29 def __init__(self, path):
30 30 self.path = path
31 31 if self.__class__.parent is not None:
32 32 self._init_parent(path)
33 33
34 34 def _init_parent(self, path):
35 35 """Try to initialize this object's parent, if it has one"""
36 36 parent_path = self.__class__._parent_path(path)
37 37 self.parent = self.__class__.parent(parent_path)
38 38 log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
39 39 self.parent._init_child(path, self)
40 40
41 41 def _init_child(self, child, path):
42 42 """Run when a child of this object is initialized.
43 43
44 44 Parameters: the child object; the path to this object (its
45 45 parent)
46 46 """
47 47
48 48 @classmethod
49 49 def _parent_path(cls, path):
50 50 """
51 51 Fetch the path of this object's parent from this object's path.
52 52 """
53 53 # os.path.dirname(), but strip directories like files (like
54 54 # unix basename)
55 55 #
56 56 # Treat directories like files...
57 57 if path[-1] == '/':
58 58 path = path[:-1]
59 59 ret = os.path.dirname(path)
60 60 return ret
61 61
62 62 @classmethod
63 63 def require_notfound(cls, path):
64 64 """Ensures a given path does not already exist"""
65 65 if os.path.exists(path):
66 66 raise exceptions.PathFoundError(path)
67 67
68 68 @classmethod
69 69 def require_found(cls, path):
70 70 """Ensures a given path already exists"""
71 71 if not os.path.exists(path):
72 72 raise exceptions.PathNotFoundError(path)
73 73
74 74 def __str__(self):
75 75 return self.path
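
A tiny sketch of the keyed-instance behaviour documented above: constructing a Pathed object (or any subclass, such as Repository) twice for the same path is expected to hand back the same cached instance.

from rhodecode.lib.dbmigrate.migrate.versioning.pathed import Pathed

a = Pathed('/tmp/some/path')
b = Pathed('/tmp/some/path')
# the class docstring promises one instance per path when possible
assert a is b
print(a)   # __str__ returns the path itself
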
@@ -1,231 +1,231 b''
1 1 """
2 2 SQLAlchemy migrate repository management.
3 3 """
4 4 import os
5 5 import shutil
6 6 import string
7 7 import logging
8 8
9 9 from pkg_resources import resource_filename
10 10 from tempita import Template as TempitaTemplate
11 11
12 from migrate import exceptions
13 from migrate.versioning import version, pathed, cfgparse
14 from migrate.versioning.template import Template
15 from migrate.versioning.config import *
12 from rhodecode.lib.dbmigrate.migrate import exceptions
13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
16 16
17 17
18 18 log = logging.getLogger(__name__)
19 19
20 20 class Changeset(dict):
21 21 """A collection of changes to be applied to a database.
22 22
23 23 Changesets are bound to a repository and manage a set of
24 24 scripts from that repository.
25 25
26 26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 27 """
28 28
29 29 def __init__(self, start, *changes, **k):
30 30 """
31 31 Give a start version; step must be explicitly stated.
32 32 """
33 33 self.step = k.pop('step', 1)
34 34 self.start = version.VerNum(start)
35 35 self.end = self.start
36 36 for change in changes:
37 37 self.add(change)
38 38
39 39 def __iter__(self):
40 40 return iter(self.items())
41 41
42 42 def keys(self):
43 43 """
44 44 In a series of upgrades x -> y, keys are version x. Sorted.
45 45 """
46 46 ret = super(Changeset, self).keys()
47 47 # Reverse order if downgrading
48 48 ret.sort(reverse=(self.step < 1))
49 49 return ret
50 50
51 51 def values(self):
52 52 return [self[k] for k in self.keys()]
53 53
54 54 def items(self):
55 55 return zip(self.keys(), self.values())
56 56
57 57 def add(self, change):
58 58 """Add new change to changeset"""
59 59 key = self.end
60 60 self.end += self.step
61 61 self[key] = change
62 62
63 63 def run(self, *p, **k):
64 64 """Run the changeset scripts"""
65 65 for version, script in self:
66 66 script.run(*p, **k)
67 67
68 68
69 69 class Repository(pathed.Pathed):
70 70 """A project's change script repository"""
71 71
72 72 _config = 'migrate.cfg'
73 73 _versions = 'versions'
74 74
75 75 def __init__(self, path):
76 76 log.debug('Loading repository %s...' % path)
77 77 self.verify(path)
78 78 super(Repository, self).__init__(path)
79 79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 80 self.versions = version.Collection(os.path.join(self.path,
81 81 self._versions))
82 82 log.debug('Repository %s loaded successfully' % path)
83 83 log.debug('Config: %r' % self.config.to_dict())
84 84
85 85 @classmethod
86 86 def verify(cls, path):
87 87 """
88 88 Ensure the target path is a valid repository.
89 89
90 90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 91 """
92 92 # Ensure the existence of required files
93 93 try:
94 94 cls.require_found(path)
95 95 cls.require_found(os.path.join(path, cls._config))
96 96 cls.require_found(os.path.join(path, cls._versions))
97 97 except exceptions.PathNotFoundError, e:
98 98 raise exceptions.InvalidRepositoryError(path)
99 99
100 100 @classmethod
101 101 def prepare_config(cls, tmpl_dir, name, options=None):
102 102 """
103 103 Prepare a project configuration file for a new project.
104 104
105 105 :param tmpl_dir: Path to Repository template
106 106 :param config_file: Name of the config file in Repository template
107 107 :param name: Repository name
108 108 :type tmpl_dir: string
109 109 :type config_file: string
110 110 :type name: string
111 111 :returns: Populated config file
112 112 """
113 113 if options is None:
114 114 options = {}
115 115 options.setdefault('version_table', 'migrate_version')
116 116 options.setdefault('repository_id', name)
117 117 options.setdefault('required_dbs', [])
118 118
119 119 tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
120 120 ret = TempitaTemplate(tmpl).substitute(options)
121 121
122 122 # cleanup
123 123 del options['__template_name__']
124 124
125 125 return ret
126 126
127 127 @classmethod
128 128 def create(cls, path, name, **opts):
129 129 """Create a repository at a specified path"""
130 130 cls.require_notfound(path)
131 131 theme = opts.pop('templates_theme', None)
132 132 t_path = opts.pop('templates_path', None)
133 133
134 134 # Create repository
135 135 tmpl_dir = Template(t_path).get_repository(theme=theme)
136 136 shutil.copytree(tmpl_dir, path)
137 137
138 138 # Edit config defaults
139 139 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
140 140 fd = open(os.path.join(path, cls._config), 'w')
141 141 fd.write(config_text)
142 142 fd.close()
143 143
144 144 opts['repository_name'] = name
145 145
146 146 # Create a management script
147 147 manager = os.path.join(path, 'manage.py')
148 148 Repository.create_manage_file(manager, templates_theme=theme,
149 149 templates_path=t_path, **opts)
150 150
151 151 return cls(path)
152 152
153 153 def create_script(self, description, **k):
154 154 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
155 155 self.versions.create_new_python_version(description, **k)
156 156
157 157 def create_script_sql(self, database, **k):
158 158 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
159 159 self.versions.create_new_sql_version(database, **k)
160 160
161 161 @property
162 162 def latest(self):
163 163 """API to :attr:`migrate.versioning.version.Collection.latest`"""
164 164 return self.versions.latest
165 165
166 166 @property
167 167 def version_table(self):
168 168 """Returns version_table name specified in config"""
169 169 return self.config.get('db_settings', 'version_table')
170 170
171 171 @property
172 172 def id(self):
173 173 """Returns repository id specified in config"""
174 174 return self.config.get('db_settings', 'repository_id')
175 175
176 176 def version(self, *p, **k):
177 177 """API to :attr:`migrate.versioning.version.Collection.version`"""
178 178 return self.versions.version(*p, **k)
179 179
180 180 @classmethod
181 181 def clear(cls):
182 182 # TODO: deletes repo
183 183 super(Repository, cls).clear()
184 184 version.Collection.clear()
185 185
186 186 def changeset(self, database, start, end=None):
187 187 """Create a changeset to migrate this database from ver. start to end/latest.
188 188
189 189 :param database: name of database to generate changeset
190 190 :param start: version to start at
191 191 :param end: version to end at (latest if None given)
192 192 :type database: string
193 193 :type start: int
194 194 :type end: int
195 195 :returns: :class:`Changeset instance <migrate.versioning.repository.Changeset>`
196 196 """
197 197 start = version.VerNum(start)
198 198
199 199 if end is None:
200 200 end = self.latest
201 201 else:
202 202 end = version.VerNum(end)
203 203
204 204 if start <= end:
205 205 step = 1
206 206 range_mod = 1
207 207 op = 'upgrade'
208 208 else:
209 209 step = -1
210 210 range_mod = 0
211 211 op = 'downgrade'
212 212
213 213 versions = range(start + range_mod, end + range_mod, step)
214 214 changes = [self.version(v).script(database, op) for v in versions]
215 215 ret = Changeset(start, step=step, *changes)
216 216 return ret
217 217
218 218 @classmethod
219 219 def create_manage_file(cls, file_, **opts):
220 220 """Create a project management script (manage.py)
221 221
222 222 :param file_: Destination file to be written
223 223 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
224 224 """
225 225 mng_file = Template(opts.pop('templates_path', None))\
226 226 .get_manage(theme=opts.pop('templates_theme', None))
227 227
228 228 tmpl = open(mng_file).read()
229 229 fd = open(file_, 'w')
230 230 fd.write(TempitaTemplate(tmpl).substitute(opts))
231 231 fd.close()
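
Not part of the commit, but as a quick illustration of the rewritten import path: the repository can be loaded and queried directly. The directory and database name below are assumptions about a typical RhodeCode checkout (a path containing migrate.cfg and a versions/ folder).

    from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository

    # assumed path: a directory containing migrate.cfg and versions/
    repo = Repository('rhodecode/lib/dbmigrate')
    print(repo.id)        # repository_id read from migrate.cfg
    print(repo.latest)    # number of the newest version script

    # build the ordered list of upgrade scripts from version 0 to latest
    changeset = repo.changeset('sqlite', 0)
    for ver, script in changeset:
        print('%s -> %s' % (ver, script.path))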
@@ -1,213 +1,213 b''
1 1 """
2 2 Database schema version management.
3 3 """
4 4 import sys
5 5 import logging
6 6
7 7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 8 create_engine)
9 9 from sqlalchemy.sql import and_
10 10 from sqlalchemy import exceptions as sa_exceptions
11 11 from sqlalchemy.sql import bindparam
12 12
13 from migrate import exceptions
14 from migrate.versioning import genmodel, schemadiff
15 from migrate.versioning.repository import Repository
16 from migrate.versioning.util import load_model
17 from migrate.versioning.version import VerNum
13 from rhodecode.lib.dbmigrate.migrate import exceptions
14 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
15 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
16 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
17 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
18 18
19 19
20 20 log = logging.getLogger(__name__)
21 21
22 22 class ControlledSchema(object):
23 23 """A database under version control"""
24 24
25 25 def __init__(self, engine, repository):
26 26 if isinstance(repository, basestring):
27 27 repository = Repository(repository)
28 28 self.engine = engine
29 29 self.repository = repository
30 30 self.meta = MetaData(engine)
31 31 self.load()
32 32
33 33 def __eq__(self, other):
34 34 """Compare two schemas by repositories and versions"""
35 35 return (self.repository is other.repository \
36 36 and self.version == other.version)
37 37
38 38 def load(self):
39 39 """Load controlled schema version info from DB"""
40 40 tname = self.repository.version_table
41 41 try:
42 42 if not hasattr(self, 'table') or self.table is None:
43 43 self.table = Table(tname, self.meta, autoload=True)
44 44
45 45 result = self.engine.execute(self.table.select(
46 46 self.table.c.repository_id == str(self.repository.id)))
47 47
48 48 data = list(result)[0]
49 49 except:
50 50 cls, exc, tb = sys.exc_info()
51 51 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
52 52
53 53 self.version = data['version']
54 54 return data
55 55
56 56 def drop(self):
57 57 """
58 58 Remove version control from a database.
59 59 """
60 60 try:
61 61 self.table.drop()
62 62 except (sa_exceptions.SQLError):
63 63 raise exceptions.DatabaseNotControlledError(str(self.table))
64 64
65 65 def changeset(self, version=None):
66 66 """API to Changeset creation.
67 67
68 68 Uses self.version for start version and engine.name
69 69 to get database name.
70 70 """
71 71 database = self.engine.name
72 72 start_ver = self.version
73 73 changeset = self.repository.changeset(database, start_ver, version)
74 74 return changeset
75 75
76 76 def runchange(self, ver, change, step):
77 77 startver = ver
78 78 endver = ver + step
79 79 # Current database version must be correct! Don't run if corrupt!
80 80 if self.version != startver:
81 81 raise exceptions.InvalidVersionError("%s is not %s" % \
82 82 (self.version, startver))
83 83 # Run the change
84 84 change.run(self.engine, step)
85 85
86 86 # Update/refresh database version
87 87 self.update_repository_table(startver, endver)
88 88 self.load()
89 89
90 90 def update_repository_table(self, startver, endver):
91 91 """Update version_table with new information"""
92 92 update = self.table.update(and_(self.table.c.version == int(startver),
93 93 self.table.c.repository_id == str(self.repository.id)))
94 94 self.engine.execute(update, version=int(endver))
95 95
96 96 def upgrade(self, version=None):
97 97 """
98 98 Upgrade (or downgrade) to a specified version, or latest version.
99 99 """
100 100 changeset = self.changeset(version)
101 101 for ver, change in changeset:
102 102 self.runchange(ver, change, changeset.step)
103 103
104 104 def update_db_from_model(self, model):
105 105 """
106 106 Modify the database to match the structure of the current Python model.
107 107 """
108 108 model = load_model(model)
109 109
110 110 diff = schemadiff.getDiffOfModelAgainstDatabase(
111 111 model, self.engine, excludeTables=[self.repository.version_table]
112 112 )
113 113 genmodel.ModelGenerator(diff,self.engine).applyModel()
114 114
115 115 self.update_repository_table(self.version, int(self.repository.latest))
116 116
117 117 self.load()
118 118
119 119 @classmethod
120 120 def create(cls, engine, repository, version=None):
121 121 """
122 122 Declare a database to be under a repository's version control.
123 123
124 124 :raises: :exc:`DatabaseAlreadyControlledError`
125 125 :returns: :class:`ControlledSchema`
126 126 """
127 127 # Confirm that the version # is valid: positive, integer,
128 128 # exists in repos
129 129 if isinstance(repository, basestring):
130 130 repository = Repository(repository)
131 131 version = cls._validate_version(repository, version)
132 132 table = cls._create_table_version(engine, repository, version)
133 133 # TODO: history table
134 134 # Load repository information and return
135 135 return cls(engine, repository)
136 136
137 137 @classmethod
138 138 def _validate_version(cls, repository, version):
139 139 """
140 140 Ensures this is a valid version number for this repository.
141 141
142 142 :raises: :exc:`InvalidVersionError` if invalid
143 143 :return: valid version number
144 144 """
145 145 if version is None:
146 146 version = 0
147 147 try:
148 148 version = VerNum(version) # raises valueerror
149 149 if version < 0 or version > repository.latest:
150 150 raise ValueError()
151 151 except ValueError:
152 152 raise exceptions.InvalidVersionError(version)
153 153 return version
154 154
155 155 @classmethod
156 156 def _create_table_version(cls, engine, repository, version):
157 157 """
158 158 Creates the versioning table in a database.
159 159
160 160 :raises: :exc:`DatabaseAlreadyControlledError`
161 161 """
162 162 # Create tables
163 163 tname = repository.version_table
164 164 meta = MetaData(engine)
165 165
166 166 table = Table(
167 167 tname, meta,
168 168 Column('repository_id', String(250), primary_key=True),
169 169 Column('repository_path', Text),
170 170 Column('version', Integer), )
171 171
172 172 # there can be multiple repositories/schemas in the same db
173 173 if not table.exists():
174 174 table.create()
175 175
176 176 # test for existing repository_id
177 177 s = table.select(table.c.repository_id == bindparam("repository_id"))
178 178 result = engine.execute(s, repository_id=repository.id)
179 179 if result.fetchone():
180 180 raise exceptions.DatabaseAlreadyControlledError
181 181
182 182 # Insert data
183 183 engine.execute(table.insert().values(
184 184 repository_id=repository.id,
185 185 repository_path=repository.path,
186 186 version=int(version)))
187 187 return table
188 188
189 189 @classmethod
190 190 def compare_model_to_db(cls, engine, model, repository):
191 191 """
192 192 Compare the current model against the current database.
193 193 """
194 194 if isinstance(repository, basestring):
195 195 repository = Repository(repository)
196 196 model = load_model(model)
197 197
198 198 diff = schemadiff.getDiffOfModelAgainstDatabase(
199 199 model, engine, excludeTables=[repository.version_table])
200 200 return diff
201 201
202 202 @classmethod
203 203 def create_model(cls, engine, repository, declarative=False):
204 204 """
205 205 Dump the current database as a Python model.
206 206 """
207 207 if isinstance(repository, basestring):
208 208 repository = Repository(repository)
209 209
210 210 diff = schemadiff.getDiffOfModelAgainstDatabase(
211 211 MetaData(), engine, excludeTables=[repository.version_table]
212 212 )
213 213 return genmodel.ModelGenerator(diff, engine, declarative).toPython()
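
A minimal sketch of how ControlledSchema ties the pieces together (illustrative only; the SQLite URL and repository path are assumptions, and the database is assumed not to be under version control yet):

    from sqlalchemy import create_engine
    from rhodecode.lib.dbmigrate.migrate.versioning.schema import ControlledSchema

    engine = create_engine('sqlite:///rhodecode.db')

    # stamp the database with version 0, then run every upgrade script
    schema = ControlledSchema.create(engine, 'rhodecode/lib/dbmigrate', version=0)
    schema.upgrade()            # end=None means "up to repository.latest"
    print(schema.version)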
@@ -1,285 +1,285 b''
1 1 """
2 2 Schema differencing support.
3 3 """
4 4
5 5 import logging
6 6 import sqlalchemy
7 7
8 from migrate.changeset import SQLA_06
8 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
9 9 from sqlalchemy.types import Float
10 10
11 11 log = logging.getLogger(__name__)
12 12
13 13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
14 14 """
15 15 Return differences of model against database.
16 16
17 17 :return: object which will evaluate to :keyword:`True` if there \
18 18 are differences else :keyword:`False`.
19 19 """
20 20 return SchemaDiff(metadata,
21 21 sqlalchemy.MetaData(engine, reflect=True),
22 22 labelA='model',
23 23 labelB='database',
24 24 excludeTables=excludeTables)
25 25
26 26
27 27 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
28 28 """
29 29 Return differences of model against another model.
30 30
31 31 :return: object which will evaluate to :keyword:`True` if there \
32 32 are differences else :keyword:`False`.
33 33 """
34 34 return SchemaDiff(metadataA, metadataB, excludeTables)
35 35
36 36
37 37 class ColDiff(object):
38 38 """
39 39 Container for differences in one :class:`~sqlalchemy.schema.Column`
40 40 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
41 41 and ``B``.
42 42
43 43 .. attribute:: col_A
44 44
45 45 The :class:`~sqlalchemy.schema.Column` object for A.
46 46
47 47 .. attribute:: col_B
48 48
49 49 The :class:`~sqlalchemy.schema.Column` object for B.
50 50
51 51 .. attribute:: type_A
52 52
53 53 The most generic type of the :class:`~sqlalchemy.schema.Column`
54 54 object in A.
55 55
56 56 .. attribute:: type_B
57 57
58 58 The most generic type of the :class:`~sqlalchemy.schema.Column`
59 59 object in B.
60 60
61 61 """
62 62
63 63 diff = False
64 64
65 65 def __init__(self,col_A,col_B):
66 66 self.col_A = col_A
67 67 self.col_B = col_B
68 68
69 69 self.type_A = col_A.type
70 70 self.type_B = col_B.type
71 71
72 72 self.affinity_A = self.type_A._type_affinity
73 73 self.affinity_B = self.type_B._type_affinity
74 74
75 75 if self.affinity_A is not self.affinity_B:
76 76 self.diff = True
77 77 return
78 78
79 79 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
80 80 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
81 81 self.diff=True
82 82 return
83 83
84 84 for attr in ('precision','scale','length'):
85 85 A = getattr(self.type_A,attr,None)
86 86 B = getattr(self.type_B,attr,None)
87 87 if not (A is None or B is None) and A!=B:
88 88 self.diff=True
89 89 return
90 90
91 91 def __nonzero__(self):
92 92 return self.diff
93 93
94 94 class TableDiff(object):
95 95 """
96 96 Container for differences in one :class:`~sqlalchemy.schema.Table`
97 97 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
98 98 and ``B``.
99 99
100 100 .. attribute:: columns_missing_from_A
101 101
102 102 A sequence of column names that were found in B but weren't in
103 103 A.
104 104
105 105 .. attribute:: columns_missing_from_B
106 106
107 107 A sequence of column names that were found in A but weren't in
108 108 B.
109 109
110 110 .. attribute:: columns_different
111 111
112 112 A dictionary containing information about columns that were
113 113 found to be different.
114 114 It maps column names to a :class:`ColDiff` objects describing the
115 115 differences found.
116 116 """
117 117 __slots__ = (
118 118 'columns_missing_from_A',
119 119 'columns_missing_from_B',
120 120 'columns_different',
121 121 )
122 122
123 123 def __nonzero__(self):
124 124 return bool(
125 125 self.columns_missing_from_A or
126 126 self.columns_missing_from_B or
127 127 self.columns_different
128 128 )
129 129
130 130 class SchemaDiff(object):
131 131 """
132 132 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
133 133 objects.
134 134
135 135 The string representation of a :class:`SchemaDiff` will summarise
136 136 the changes found between the two
137 137 :class:`~sqlalchemy.schema.MetaData` objects.
138 138
139 139 The length of a :class:`SchemaDiff` will give the number of
140 140 changes found, enabling it to be used much like a boolean in
141 141 expressions.
142 142
143 143 :param metadataA:
144 144 First :class:`~sqlalchemy.schema.MetaData` to compare.
145 145
146 146 :param metadataB:
147 147 Second :class:`~sqlalchemy.schema.MetaData` to compare.
148 148
149 149 :param labelA:
150 150 The label to use in messages about the first
151 151 :class:`~sqlalchemy.schema.MetaData`.
152 152
153 153 :param labelB:
154 154 The label to use in messages about the second
155 155 :class:`~sqlalchemy.schema.MetaData`.
156 156
157 157 :param excludeTables:
158 158 A sequence of table names to exclude.
159 159
160 160 .. attribute:: tables_missing_from_A
161 161
162 162 A sequence of table names that were found in B but weren't in
163 163 A.
164 164
165 165 .. attribute:: tables_missing_from_B
166 166
167 167 A sequence of table names that were found in A but weren't in
168 168 B.
169 169
170 170 .. attribute:: tables_different
171 171
172 172 A dictionary containing information about tables that were found
173 173 to be different.
174 174 It maps table names to a :class:`TableDiff` objects describing the
175 175 differences found.
176 176 """
177 177
178 178 def __init__(self,
179 179 metadataA, metadataB,
180 180 labelA='metadataA',
181 181 labelB='metadataB',
182 182 excludeTables=None):
183 183
184 184 self.metadataA, self.metadataB = metadataA, metadataB
185 185 self.labelA, self.labelB = labelA, labelB
186 186 self.label_width = max(len(labelA),len(labelB))
187 187 excludeTables = set(excludeTables or [])
188 188
189 189 A_table_names = set(metadataA.tables.keys())
190 190 B_table_names = set(metadataB.tables.keys())
191 191
192 192 self.tables_missing_from_A = sorted(
193 193 B_table_names - A_table_names - excludeTables
194 194 )
195 195 self.tables_missing_from_B = sorted(
196 196 A_table_names - B_table_names - excludeTables
197 197 )
198 198
199 199 self.tables_different = {}
200 200 for table_name in A_table_names.intersection(B_table_names):
201 201
202 202 td = TableDiff()
203 203
204 204 A_table = metadataA.tables[table_name]
205 205 B_table = metadataB.tables[table_name]
206 206
207 207 A_column_names = set(A_table.columns.keys())
208 208 B_column_names = set(B_table.columns.keys())
209 209
210 210 td.columns_missing_from_A = sorted(
211 211 B_column_names - A_column_names
212 212 )
213 213
214 214 td.columns_missing_from_B = sorted(
215 215 A_column_names - B_column_names
216 216 )
217 217
218 218 td.columns_different = {}
219 219
220 220 for col_name in A_column_names.intersection(B_column_names):
221 221
222 222 cd = ColDiff(
223 223 A_table.columns.get(col_name),
224 224 B_table.columns.get(col_name)
225 225 )
226 226
227 227 if cd:
228 228 td.columns_different[col_name]=cd
229 229
230 230 # XXX - index and constraint differences should
231 231 # be checked for here
232 232
233 233 if td:
234 234 self.tables_different[table_name]=td
235 235
236 236 def __str__(self):
237 237 ''' Summarize differences. '''
238 238 out = []
239 239 column_template =' %%%is: %%r' % self.label_width
240 240
241 241 for names,label in (
242 242 (self.tables_missing_from_A,self.labelA),
243 243 (self.tables_missing_from_B,self.labelB),
244 244 ):
245 245 if names:
246 246 out.append(
247 247 ' tables missing from %s: %s' % (
248 248 label,', '.join(sorted(names))
249 249 )
250 250 )
251 251
252 252 for name,td in sorted(self.tables_different.items()):
253 253 out.append(
254 254 ' table with differences: %s' % name
255 255 )
256 256 for names,label in (
257 257 (td.columns_missing_from_A,self.labelA),
258 258 (td.columns_missing_from_B,self.labelB),
259 259 ):
260 260 if names:
261 261 out.append(
262 262 ' %s missing these columns: %s' % (
263 263 label,', '.join(sorted(names))
264 264 )
265 265 )
266 266 for name,cd in td.columns_different.items():
267 267 out.append(' column with differences: %s' % name)
268 268 out.append(column_template % (self.labelA,cd.col_A))
269 269 out.append(column_template % (self.labelB,cd.col_B))
270 270
271 271 if out:
272 272 out.insert(0, 'Schema diffs:')
273 273 return '\n'.join(out)
274 274 else:
275 275 return 'No schema diffs'
276 276
277 277 def __len__(self):
278 278 """
279 279 Used in bool evaluation, return of 0 means no diffs.
280 280 """
281 281 return (
282 282 len(self.tables_missing_from_A) +
283 283 len(self.tables_missing_from_B) +
284 284 len(self.tables_different)
285 285 )
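
For illustration, a model-versus-database comparison could look like the sketch below; the import of a declarative Base from rhodecode.model.meta is an assumption, not something this commit introduces.

    from sqlalchemy import create_engine
    from rhodecode.lib.dbmigrate.migrate.versioning import schemadiff
    from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
    from rhodecode.model.meta import Base   # assumed location of the declarative base

    engine = create_engine('sqlite:///rhodecode.db')
    repo = Repository('rhodecode/lib/dbmigrate')

    diff = schemadiff.getDiffOfModelAgainstDatabase(
        Base.metadata, engine, excludeTables=[repo.version_table])
    if diff:          # SchemaDiff is truthy when any difference was found
        print(diff)   # per-table summary of missing/differing columns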
@@ -1,6 +1,6 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 from migrate.versioning.script.base import BaseScript
5 from migrate.versioning.script.py import PythonScript
6 from migrate.versioning.script.sql import SqlScript
4 from rhodecode.lib.dbmigrate.migrate.versioning.script.base import BaseScript
5 from rhodecode.lib.dbmigrate.migrate.versioning.script.py import PythonScript
6 from rhodecode.lib.dbmigrate.migrate.versioning.script.sql import SqlScript
@@ -1,57 +1,57 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3 import logging
4 4
5 from migrate import exceptions
6 from migrate.versioning.config import operations
7 from migrate.versioning import pathed
5 from rhodecode.lib.dbmigrate.migrate import exceptions
6 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
7 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
8 8
9 9
10 10 log = logging.getLogger(__name__)
11 11
12 12 class BaseScript(pathed.Pathed):
13 13 """Base class for other types of scripts.
14 14 All scripts have the following properties:
15 15
16 16 source (script.source())
17 17 The source code of the script
18 18 version (script.version())
19 19 The version number of the script
20 20 operations (script.operations())
21 21 The operations defined by the script: upgrade(), downgrade() or both.
22 22 Returns a tuple of operations.
23 23 Can also check for an operation with ex. script.operation(Script.ops.up)
24 24 """ # TODO: sphinxfy this and implement it correctly
25 25
26 26 def __init__(self, path):
27 27 log.debug('Loading script %s...' % path)
28 28 self.verify(path)
29 29 super(BaseScript, self).__init__(path)
30 30 log.debug('Script %s loaded successfully' % path)
31 31
32 32 @classmethod
33 33 def verify(cls, path):
34 34 """Ensure this is a valid script
35 35 This version simply ensures the script file's existence
36 36
37 37 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
38 38 """
39 39 try:
40 40 cls.require_found(path)
41 41 except:
42 42 raise exceptions.InvalidScriptError(path)
43 43
44 44 def source(self):
45 45 """:returns: source code of the script.
46 46 :rtype: string
47 47 """
48 48 fd = open(self.path)
49 49 ret = fd.read()
50 50 fd.close()
51 51 return ret
52 52
53 53 def run(self, engine):
54 54 """Core of each BaseScript subclass.
55 55 This method executes the script.
56 56 """
57 57 raise NotImplementedError()
@@ -1,159 +1,159 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import shutil
5 5 import warnings
6 6 import logging
7 7 from StringIO import StringIO
8 8
9 import migrate
10 from migrate.versioning import genmodel, schemadiff
11 from migrate.versioning.config import operations
12 from migrate.versioning.template import Template
13 from migrate.versioning.script import base
14 from migrate.versioning.util import import_path, load_model, with_engine
15 from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
9 from rhodecode.lib.dbmigrate import migrate
10 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
11 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
12 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
13 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
15 from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
16 16
17 17 log = logging.getLogger(__name__)
18 18 __all__ = ['PythonScript']
19 19
20 20
21 21 class PythonScript(base.BaseScript):
22 22 """Base for Python scripts"""
23 23
24 24 @classmethod
25 25 def create(cls, path, **opts):
26 26 """Create an empty migration script at specified path
27 27
28 28 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
29 29 cls.require_notfound(path)
30 30
31 31 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
32 32 shutil.copy(src, path)
33 33
34 34 return cls(path)
35 35
36 36 @classmethod
37 37 def make_update_script_for_model(cls, engine, oldmodel,
38 38 model, repository, **opts):
39 39 """Create a migration script based on difference between two SA models.
40 40
41 41 :param repository: path to migrate repository
42 42 :param oldmodel: dotted.module.name:SAClass or SAClass object
43 43 :param model: dotted.module.name:SAClass or SAClass object
44 44 :param engine: SQLAlchemy engine
45 45 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
46 46 :type oldmodel: string or Class
47 47 :type model: string or Class
48 48 :type engine: Engine instance
49 49 :returns: Upgrade / Downgrade script
50 50 :rtype: string
51 51 """
52
52
53 53 if isinstance(repository, basestring):
54 54 # oh dear, an import cycle!
55 from migrate.versioning.repository import Repository
55 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
56 56 repository = Repository(repository)
57 57
58 58 oldmodel = load_model(oldmodel)
59 59 model = load_model(model)
60 60
61 61 # Compute differences.
62 62 diff = schemadiff.getDiffOfModelAgainstModel(
63 63 oldmodel,
64 64 model,
65 65 excludeTables=[repository.version_table])
66 66 # TODO: diff can be False (there is no difference?)
67 67 decls, upgradeCommands, downgradeCommands = \
68 genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
68 genmodel.ModelGenerator(diff, engine).toUpgradeDowngradePython()
69 69
70 70 # Store differences into file.
71 71 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
72 72 f = open(src)
73 73 contents = f.read()
74 74 f.close()
75 75
76 76 # generate source
77 77 search = 'def upgrade(migrate_engine):'
78 78 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
79 79 if upgradeCommands:
80 80 contents = contents.replace(' pass', upgradeCommands, 1)
81 81 if downgradeCommands:
82 82 contents = contents.replace(' pass', downgradeCommands, 1)
83 83 return contents
84 84
85 85 @classmethod
86 86 def verify_module(cls, path):
87 87 """Ensure path is a valid script
88 88
89 89 :param path: Script location
90 90 :type path: string
91 91 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
92 92 :returns: Python module
93 93 """
94 94 # Try to import and get the upgrade() func
95 95 module = import_path(path)
96 96 try:
97 97 assert callable(module.upgrade)
98 98 except Exception, e:
99 99 raise InvalidScriptError(path + ': %s' % str(e))
100 100 return module
101 101
102 102 def preview_sql(self, url, step, **args):
103 103 """Mocks SQLAlchemy Engine to store all executed calls in a string
104 104 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
105 105
106 106 :returns: SQL file
107 107 """
108 108 buf = StringIO()
109 109 args['engine_arg_strategy'] = 'mock'
110 110 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
111 111
112 112 @with_engine
113 113 def go(url, step, **kw):
114 114 engine = kw.pop('engine')
115 115 self.run(engine, step)
116 116 return buf.getvalue()
117 117
118 118 return go(url, step, **args)
119 119
120 120 def run(self, engine, step):
121 121 """Core method of Script file.
122 122 Executes :func:`upgrade` or :func:`downgrade` functions
123 123
124 124 :param engine: SQLAlchemy Engine
125 125 :param step: Operation to run
126 126 :type engine: Engine instance
127 127 :type step: int
128 128 """
129 129 if step > 0:
130 130 op = 'upgrade'
131 131 elif step < 0:
132 132 op = 'downgrade'
133 133 else:
134 134 raise ScriptError("%d is not a valid step" % step)
135 135
136 136 funcname = base.operations[op]
137 137 script_func = self._func(funcname)
138 138
139 139 try:
140 140 script_func(engine)
141 141 except TypeError:
142 142 warnings.warn("upgrade/downgrade functions must accept engine"
143 143 " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
144 144 raise
145 145
146 146 @property
147 147 def module(self):
148 148 """Calls :meth:`migrate.versioning.script.py.verify_module`
149 149 and returns it.
150 150 """
151 151 if not hasattr(self, '_module'):
152 152 self._module = self.verify_module(self.path)
153 153 return self._module
154 154
155 155 def _func(self, funcname):
156 156 if not hasattr(self.module, funcname):
157 157 msg = "Function '%s' is not defined in this script"
158 158 raise ScriptError(msg % funcname)
159 159 return getattr(self.module, funcname)
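
A hedged example of what PythonScript offers once the imports point at the vendored package; the script filename below is hypothetical.

    from rhodecode.lib.dbmigrate.migrate.versioning.script import PythonScript

    # hypothetical version script under the repository's versions/ directory
    script = PythonScript('rhodecode/lib/dbmigrate/versions/002_example.py')

    # render the SQL an upgrade (step=1) would execute, using a mocked engine
    print(script.preview_sql('sqlite:///rhodecode.db', 1))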
@@ -1,49 +1,49 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3 import logging
4 4 import shutil
5 5
6 from migrate.versioning.script import base
7 from migrate.versioning.template import Template
6 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
7 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
8 8
9 9
10 10 log = logging.getLogger(__name__)
11 11
12 12 class SqlScript(base.BaseScript):
13 13 """A file containing plain SQL statements."""
14 14
15 15 @classmethod
16 16 def create(cls, path, **opts):
17 17 """Create an empty migration script at specified path
18 18
19 19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
20 20 cls.require_notfound(path)
21 21
22 22 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
23 23 shutil.copy(src, path)
24 24 return cls(path)
25 25
26 26 # TODO: why is step parameter even here?
27 27 def run(self, engine, step=None, executemany=True):
28 28 """Runs SQL script through raw dbapi execute call"""
29 29 text = self.source()
30 30 # Don't rely on SA's autocommit here
31 31 # (SA uses .startswith to check if a commit is needed. What if script
32 32 # starts with a comment?)
33 33 conn = engine.connect()
34 34 try:
35 35 trans = conn.begin()
36 36 try:
37 37 # HACK: SQLite doesn't allow multiple statements through
38 38 # its execute() method, but it provides executescript() instead
39 39 dbapi = conn.engine.raw_connection()
40 40 if executemany and getattr(dbapi, 'executescript', None):
41 41 dbapi.executescript(text)
42 42 else:
43 43 conn.execute(text)
44 44 trans.commit()
45 45 except:
46 46 trans.rollback()
47 47 raise
48 48 finally:
49 49 conn.close()
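
Correspondingly, a plain SQL version script can be executed directly; the filename follows the ###_database_operation.sql convention used by version.Collection and is only illustrative.

    from sqlalchemy import create_engine
    from rhodecode.lib.dbmigrate.migrate.versioning.script import SqlScript

    engine = create_engine('sqlite:///rhodecode.db')
    SqlScript('rhodecode/lib/dbmigrate/versions/002_sqlite_upgrade.sql').run(engine)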
@@ -1,215 +1,215 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 """The migrate command-line tool."""
5 5
6 6 import sys
7 7 import inspect
8 8 import logging
9 9 from optparse import OptionParser, BadOptionError
10 10
11 from migrate import exceptions
12 from migrate.versioning import api
13 from migrate.versioning.config import *
14 from migrate.versioning.util import asbool
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
15 15
16 16
17 17 alias = dict(
18 18 s=api.script,
19 19 vc=api.version_control,
20 20 dbv=api.db_version,
21 21 v=api.version,
22 22 )
23 23
24 24 def alias_setup():
25 25 global alias
26 26 for key, val in alias.iteritems():
27 27 setattr(api, key, val)
28 28 alias_setup()
29 29
30 30
31 31 class PassiveOptionParser(OptionParser):
32 32
33 33 def _process_args(self, largs, rargs, values):
34 34 """little hack to support all --some_option=value parameters"""
35 35
36 36 while rargs:
37 37 arg = rargs[0]
38 38 if arg == "--":
39 39 del rargs[0]
40 40 return
41 41 elif arg[0:2] == "--":
42 42 # if parser does not know about the option
43 43 # pass it along (make it anonymous)
44 44 try:
45 45 opt = arg.split('=', 1)[0]
46 46 self._match_long_opt(opt)
47 47 except BadOptionError:
48 48 largs.append(arg)
49 49 del rargs[0]
50 50 else:
51 51 self._process_long_opt(rargs, values)
52 52 elif arg[:1] == "-" and len(arg) > 1:
53 53 self._process_short_opts(rargs, values)
54 54 elif self.allow_interspersed_args:
55 55 largs.append(arg)
56 56 del rargs[0]
57 57
58 58 def main(argv=None, **kwargs):
59 59 """Shell interface to :mod:`migrate.versioning.api`.
60 60
61 61 kwargs are default options that can be overridden by passing
62 62 --some_option as a command-line option
63 63
64 64 :param disable_logging: Skip migrate's logging configuration when set to True
65 65 :type disable_logging: bool
66 66 """
67 67 if argv is not None:
68 68 argv = argv
69 69 else:
70 70 argv = list(sys.argv[1:])
71 71 commands = list(api.__all__)
72 72 commands.sort()
73 73
74 74 usage = """%%prog COMMAND ...
75 75
76 76 Available commands:
77 77 %s
78 78
79 79 Enter "%%prog help COMMAND" for information on a particular command.
80 80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
81 81 api.command_desc.get(command)) for command in commands])
82 82
83 83 parser = PassiveOptionParser(usage=usage)
84 84 parser.add_option("-d", "--debug",
85 85 action="store_true",
86 86 dest="debug",
87 87 default=False,
88 88 help="Shortcut to turn on DEBUG mode for logging")
89 89 parser.add_option("-q", "--disable_logging",
90 90 action="store_true",
91 91 dest="disable_logging",
92 92 default=False,
93 93 help="Use this option to disable logging configuration")
94 94 help_commands = ['help', '-h', '--help']
95 95 HELP = False
96 96
97 97 try:
98 98 command = argv.pop(0)
99 99 if command in help_commands:
100 100 HELP = True
101 101 command = argv.pop(0)
102 102 except IndexError:
103 103 parser.print_help()
104 104 return
105 105
106 106 command_func = getattr(api, command, None)
107 107 if command_func is None or command.startswith('_'):
108 108 parser.error("Invalid command %s" % command)
109 109
110 110 parser.set_usage(inspect.getdoc(command_func))
111 111 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
112 112 for arg in f_args:
113 113 parser.add_option(
114 114 "--%s" % arg,
115 115 dest=arg,
116 116 action='store',
117 117 type="string")
118 118
119 119 # display help of the current command
120 120 if HELP:
121 121 parser.print_help()
122 122 return
123 123
124 124 options, args = parser.parse_args(argv)
125 125
126 126 # override kwargs with anonymous parameters
127 127 override_kwargs = dict()
128 128 for arg in list(args):
129 129 if arg.startswith('--'):
130 130 args.remove(arg)
131 131 if '=' in arg:
132 132 opt, value = arg[2:].split('=', 1)
133 133 else:
134 134 opt = arg[2:]
135 135 value = True
136 136 override_kwargs[opt] = value
137 137
138 138 # override kwargs with options if user is overwriting
139 139 for key, value in options.__dict__.iteritems():
140 140 if value is not None:
141 141 override_kwargs[key] = value
142 142
143 143 # arguments that function accepts without passed kwargs
144 144 f_required = list(f_args)
145 145 candidates = dict(kwargs)
146 146 candidates.update(override_kwargs)
147 147 for key, value in candidates.iteritems():
148 148 if key in f_args:
149 149 f_required.remove(key)
150 150
151 151 # map function arguments to parsed arguments
152 152 for arg in args:
153 153 try:
154 154 kw = f_required.pop(0)
155 155 except IndexError:
156 156 parser.error("Too many arguments for command %s: %s" % (command,
157 157 arg))
158 158 kwargs[kw] = arg
159 159
160 160 # apply overrides
161 161 kwargs.update(override_kwargs)
162 162
163 163 # configure options
164 164 for key, value in options.__dict__.iteritems():
165 165 kwargs.setdefault(key, value)
166 166
167 167 # configure logging
168 168 if not asbool(kwargs.pop('disable_logging', False)):
169 169 # filter to log <= INFO into stdout and the rest to stderr
170 170 class SingleLevelFilter(logging.Filter):
171 171 def __init__(self, min=None, max=None):
172 172 self.min = min or 0
173 173 self.max = max or 100
174 174
175 175 def filter(self, record):
176 176 return self.min <= record.levelno <= self.max
177 177
178 178 logger = logging.getLogger()
179 179 h1 = logging.StreamHandler(sys.stdout)
180 180 f1 = SingleLevelFilter(max=logging.INFO)
181 181 h1.addFilter(f1)
182 182 h2 = logging.StreamHandler(sys.stderr)
183 183 f2 = SingleLevelFilter(min=logging.WARN)
184 184 h2.addFilter(f2)
185 185 logger.addHandler(h1)
186 186 logger.addHandler(h2)
187 187
188 188 if options.debug:
189 189 logger.setLevel(logging.DEBUG)
190 190 else:
191 191 logger.setLevel(logging.INFO)
192 192
193 193 log = logging.getLogger(__name__)
194 194
195 195 # check if all args are given
196 196 try:
197 197 num_defaults = len(f_defaults)
198 198 except TypeError:
199 199 num_defaults = 0
200 200 f_args_default = f_args[len(f_args) - num_defaults:]
201 201 required = list(set(f_required) - set(f_args_default))
202 202 if required:
203 203 parser.error("Not enough arguments for command %s: %s not specified" \
204 204 % (command, ', '.join(required)))
205 205
206 206 # handle command
207 207 try:
208 208 ret = command_func(**kwargs)
209 209 if ret is not None:
210 210 log.info(ret)
211 211 except (exceptions.UsageError, exceptions.KnownError), e:
212 212 parser.error(e.args[0])
213 213
214 214 if __name__ == "__main__":
215 215 main()
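
A generated manage.py usually boils down to a call into this shell module. The sketch below assumes the repository path and database URL of a local installation:

    #!/usr/bin/env python
    from rhodecode.lib.dbmigrate.migrate.versioning.shell import main

    if __name__ == '__main__':
        main(url='sqlite:///rhodecode.db',
             repository='rhodecode/lib/dbmigrate',
             debug='False')

With such a script in place, the usual commands (for example `python manage.py version_control` followed by `python manage.py upgrade`, assuming the standard migrate api command set) are dispatched to migrate.versioning.api.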
@@ -1,94 +1,94 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import os
5 5 import shutil
6 6 import sys
7 7
8 8 from pkg_resources import resource_filename
9 9
10 from migrate.versioning.config import *
11 from migrate.versioning import pathed
10 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
11 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
12 12
13 13
14 14 class Collection(pathed.Pathed):
15 15 """A collection of templates of a specific type"""
16 16 _mask = None
17 17
18 18 def get_path(self, file):
19 19 return os.path.join(self.path, str(file))
20 20
21 21
22 22 class RepositoryCollection(Collection):
23 23 _mask = '%s'
24 24
25 25 class ScriptCollection(Collection):
26 26 _mask = '%s.py_tmpl'
27 27
28 28 class ManageCollection(Collection):
29 29 _mask = '%s.py_tmpl'
30 30
31 31 class SQLScriptCollection(Collection):
32 32 _mask = '%s.py_tmpl'
33 33
34 34 class Template(pathed.Pathed):
35 35 """Finds the paths/packages of various Migrate templates.
36 36
37 :param path: Templates are loaded from migrate package
37 :param path: Templates are loaded from rhodecode.lib.dbmigrate.migrate package
38 38 if `path` is not provided.
39 39 """
40 40 pkg = 'migrate.versioning.templates'
41 41 _manage = 'manage.py_tmpl'
42 42
43 43 def __new__(cls, path=None):
44 44 if path is None:
45 45 path = cls._find_path(cls.pkg)
46 46 return super(Template, cls).__new__(cls, path)
47 47
48 48 def __init__(self, path=None):
49 49 if path is None:
50 50 path = Template._find_path(self.pkg)
51 51 super(Template, self).__init__(path)
52 52 self.repository = RepositoryCollection(os.path.join(path, 'repository'))
53 53 self.script = ScriptCollection(os.path.join(path, 'script'))
54 54 self.manage = ManageCollection(os.path.join(path, 'manage'))
55 55 self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
56 56
57 57 @classmethod
58 58 def _find_path(cls, pkg):
59 59 """Returns absolute path to dotted python package."""
60 60 tmp_pkg = pkg.rsplit('.', 1)
61 61
62 62 if len(tmp_pkg) != 1:
63 63 return resource_filename(tmp_pkg[0], tmp_pkg[1])
64 64 else:
65 65 return resource_filename(tmp_pkg[0], '')
66 66
67 67 def _get_item(self, collection, theme=None):
68 68 """Locates and returns collection.
69 69
70 70 :param collection: name of collection to locate
71 71 :param type_: type of subfolder in collection (defaults to "_default")
72 72 :returns: (package, source)
73 73 :rtype: str, str
74 74 """
75 75 item = getattr(self, collection)
76 76 theme_mask = getattr(item, '_mask')
77 77 theme = theme_mask % (theme or 'default')
78 78 return item.get_path(theme)
79 79
80 80 def get_repository(self, *a, **kw):
81 81 """Calls self._get_item('repository', *a, **kw)"""
82 82 return self._get_item('repository', *a, **kw)
83 83
84 84 def get_script(self, *a, **kw):
85 85 """Calls self._get_item('script', *a, **kw)"""
86 86 return self._get_item('script', *a, **kw)
87 87
88 88 def get_sql_script(self, *a, **kw):
89 89 """Calls self._get_item('sql_script', *a, **kw)"""
90 90 return self._get_item('sql_script', *a, **kw)
91 91
92 92 def get_manage(self, *a, **kw):
93 93 """Calls self._get_item('manage', *a, **kw)"""
94 94 return self._get_item('manage', *a, **kw)
@@ -1,179 +1,179 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3 """.. currentmodule:: migrate.versioning.util"""
4 4
5 5 import warnings
6 6 import logging
7 7 from decorator import decorator
8 8 from pkg_resources import EntryPoint
9 9
10 10 from sqlalchemy import create_engine
11 11 from sqlalchemy.engine import Engine
12 12 from sqlalchemy.pool import StaticPool
13 13
14 from migrate import exceptions
15 from migrate.versioning.util.keyedinstance import KeyedInstance
16 from migrate.versioning.util.importpath import import_path
14 from rhodecode.lib.dbmigrate.migrate import exceptions
15 from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance
16 from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path
17 17
18 18
19 19 log = logging.getLogger(__name__)
20 20
21 21 def load_model(dotted_name):
22 22 """Import module and use module-level variable".
23 23
24 24 :param dotted_name: path to model in form of string: ``some.python.module:Class``
25 25
26 26 .. versionchanged:: 0.5.4
27 27
28 28 """
29 29 if isinstance(dotted_name, basestring):
30 30 if ':' not in dotted_name:
31 31 # backwards compatibility
32 32 warnings.warn('model should be in form of module.model:User '
33 33 'and not module.model.User', exceptions.MigrateDeprecationWarning)
34 34 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
35 35 return EntryPoint.parse('x=%s' % dotted_name).load(False)
36 36 else:
37 37 # Assume it's already loaded.
38 38 return dotted_name
39 39
40 40 def asbool(obj):
41 41 """Do everything to use object as bool"""
42 42 if isinstance(obj, basestring):
43 43 obj = obj.strip().lower()
44 44 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
45 45 return True
46 46 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
47 47 return False
48 48 else:
49 49 raise ValueError("String is not true/false: %r" % obj)
50 50 if obj in (True, False):
51 51 return bool(obj)
52 52 else:
53 53 raise ValueError("String is not true/false: %r" % obj)
54 54
55 55 def guess_obj_type(obj):
56 56 """Do everything to guess object type from string
57 57
58 58 Tries to convert to `int`, then `bool`, and finally returns the object unchanged if neither succeeds.
59 59
60 60 .. versionadded: 0.5.4
61 61 """
62 62
63 63 result = None
64 64
65 65 try:
66 66 result = int(obj)
67 67 except:
68 68 pass
69 69
70 70 if result is None:
71 71 try:
72 72 result = asbool(obj)
73 73 except:
74 74 pass
75 75
76 76 if result is not None:
77 77 return result
78 78 else:
79 79 return obj
80 80
81 81 @decorator
82 82 def catch_known_errors(f, *a, **kw):
83 83 """Decorator that catches known api errors
84 84
85 85 .. versionadded: 0.5.4
86 86 """
87 87
88 88 try:
89 89 return f(*a, **kw)
90 90 except exceptions.PathFoundError, e:
91 91 raise exceptions.KnownError("The path %s already exists" % e.args[0])
92 92
93 93 def construct_engine(engine, **opts):
94 94 """.. versionadded:: 0.5.4
95 95
96 96 Constructs and returns SQLAlchemy engine.
97 97
98 98 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
99 99
100 100 :param engine: connection string or an existing engine
101 101 :param engine_dict: python dictionary of options to pass to `create_engine`
102 102 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
103 103 :type engine_dict: dict
104 104 :type engine: string or Engine instance
105 105 :type engine_arg_*: string
106 106 :returns: SQLAlchemy Engine
107 107
108 108 .. note::
109 109
110 110 keyword parameters override ``engine_dict`` values.
111 111
112 112 """
113 113 if isinstance(engine, Engine):
114 114 return engine
115 115 elif not isinstance(engine, basestring):
116 116 raise ValueError("you need to pass either an existing engine or a database uri")
117 117
118 118 # get options for create_engine
119 119 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
120 120 kwargs = opts['engine_dict']
121 121 else:
122 122 kwargs = dict()
123 123
124 124 # DEPRECATED: handle echo the old way
125 125 echo = asbool(opts.get('echo', False))
126 126 if echo:
127 127 warnings.warn('echo=True parameter is deprecated, pass '
128 128 'engine_arg_echo=True or engine_dict={"echo": True}',
129 129 exceptions.MigrateDeprecationWarning)
130 130 kwargs['echo'] = echo
131 131
132 132 # parse keyword arguments
133 133 for key, value in opts.iteritems():
134 134 if key.startswith('engine_arg_'):
135 135 kwargs[key[11:]] = guess_obj_type(value)
136 136
137 137 log.debug('Constructing engine')
138 138 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
139 139 # seems like 0.5.x branch does not work with engine.dispose and staticpool
140 140 return create_engine(engine, **kwargs)
141 141
142 142 @decorator
143 143 def with_engine(f, *a, **kw):
144 144 """Decorator for :mod:`migrate.versioning.api` functions
145 145 to safely close resources after function usage.
146 146
147 147 Passes engine parameters to :func:`construct_engine` and
148 148 resulting parameter is available as kw['engine'].
149 149
150 150 Engine is disposed after wrapped function is executed.
151 151
152 152 .. versionadded: 0.6.0
153 153 """
154 154 url = a[0]
155 155 engine = construct_engine(url, **kw)
156 156
157 157 try:
158 158 kw['engine'] = engine
159 159 return f(*a, **kw)
160 160 finally:
161 161 if isinstance(engine, Engine):
162 162 log.debug('Disposing SQLAlchemy engine %s', engine)
163 163 engine.dispose()
164 164
165 165
166 166 class Memoize:
167 167 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
168 168 Will only work on functions with non-mutable arguments
169 169
170 170 ActiveState Code 52201
171 171 """
172 172 def __init__(self, fn):
173 173 self.fn = fn
174 174 self.memo = {}
175 175
176 176 def __call__(self, *args):
177 177 if not self.memo.has_key(args):
178 178 self.memo[args] = self.fn(*args)
179 179 return self.memo[args]
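
To show how the engine_arg_* coercion works in practice (a small sketch; the URL is an assumption):

    from rhodecode.lib.dbmigrate.migrate.versioning.util import construct_engine

    # string values are coerced via guess_obj_type/asbool before reaching create_engine
    engine = construct_engine('sqlite:///rhodecode.db', engine_arg_echo='true')
    print(engine.echo)    # True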
@@ -1,215 +1,215 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3
4 4 import os
5 5 import re
6 6 import shutil
7 7 import logging
8 8
9 from migrate import exceptions
10 from migrate.versioning import pathed, script
9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
11 11
12 12
13 13 log = logging.getLogger(__name__)
14 14
15 15 class VerNum(object):
16 16 """A version number that behaves like a string and int at the same time"""
17 17
18 18 _instances = dict()
19 19
20 20 def __new__(cls, value):
21 21 val = str(value)
22 22 if val not in cls._instances:
23 23 cls._instances[val] = super(VerNum, cls).__new__(cls)
24 24 ret = cls._instances[val]
25 25 return ret
26 26
27 27 def __init__(self,value):
28 28 self.value = str(int(value))
29 29 if self < 0:
30 30 raise ValueError("Version number cannot be negative")
31 31
32 32 def __add__(self, value):
33 33 ret = int(self) + int(value)
34 34 return VerNum(ret)
35 35
36 36 def __sub__(self, value):
37 37 return self + (int(value) * -1)
38 38
39 39 def __cmp__(self, value):
40 40 return int(self) - int(value)
41 41
42 42 def __repr__(self):
43 43 return "<VerNum(%s)>" % self.value
44 44
45 45 def __str__(self):
46 46 return str(self.value)
47 47
48 48 def __int__(self):
49 49 return int(self.value)
50 50
51 51
52 52 class Collection(pathed.Pathed):
53 53 """A collection of versioning scripts in a repository"""
54 54
55 55 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
56 56
57 57 def __init__(self, path):
58 58 """Collect current version scripts in repository
59 59 and store them in self.versions
60 60 """
61 61 super(Collection, self).__init__(path)
62 62
63 63 # Create temporary list of files, allowing skipped version numbers.
64 64 files = os.listdir(path)
65 65 if '1' in files:
66 66 # deprecation
67 67 raise Exception('It looks like you have a repository in the old '
68 68 'format (with directories for each version). '
69 69 'Please convert repository before proceeding.')
70 70
71 71 tempVersions = dict()
72 72 for filename in files:
73 73 match = self.FILENAME_WITH_VERSION.match(filename)
74 74 if match:
75 75 num = int(match.group(1))
76 76 tempVersions.setdefault(num, []).append(filename)
77 77 else:
78 78 pass # Must be a helper file or something, let's ignore it.
79 79
80 80 # Create the versions member where the keys
81 81 # are VerNum's and the values are Version's.
82 82 self.versions = dict()
83 83 for num, files in tempVersions.items():
84 84 self.versions[VerNum(num)] = Version(num, path, files)
85 85
86 86 @property
87 87 def latest(self):
88 88 """:returns: Latest version in Collection"""
89 89 return max([VerNum(0)] + self.versions.keys())
90 90
91 91 def create_new_python_version(self, description, **k):
92 92 """Create Python files for new version"""
93 93 ver = self.latest + 1
94 94 extra = str_to_filename(description)
95 95
96 96 if extra:
97 97 if extra == '_':
98 98 extra = ''
99 99 elif not extra.startswith('_'):
100 100 extra = '_%s' % extra
101 101
102 102 filename = '%03d%s.py' % (ver, extra)
103 103 filepath = self._version_path(filename)
104 104
105 105 script.PythonScript.create(filepath, **k)
106 106 self.versions[ver] = Version(ver, self.path, [filename])
107 107
108 108 def create_new_sql_version(self, database, **k):
109 109 """Create SQL files for new version"""
110 110 ver = self.latest + 1
111 111 self.versions[ver] = Version(ver, self.path, [])
112 112
113 113 # Create new files.
114 114 for op in ('upgrade', 'downgrade'):
115 115 filename = '%03d_%s_%s.sql' % (ver, database, op)
116 116 filepath = self._version_path(filename)
117 117 script.SqlScript.create(filepath, **k)
118 118 self.versions[ver].add_script(filepath)
119 119
120 120 def version(self, vernum=None):
121 121 """Returns latest Version if vernum is not given.
122 122 Otherwise, returns wanted version"""
123 123 if vernum is None:
124 124 vernum = self.latest
125 125 return self.versions[VerNum(vernum)]
126 126
127 127 @classmethod
128 128 def clear(cls):
129 129 super(Collection, cls).clear()
130 130
131 131 def _version_path(self, ver):
132 132 """Returns path of file in versions repository"""
133 133 return os.path.join(self.path, str(ver))
134 134
135 135
136 136 class Version(object):
137 137 """A single version in a collection
138 138 :param vernum: Version Number
139 139 :param path: Path to script files
140 140 :param filelist: List of scripts
141 141 :type vernum: int, VerNum
142 142 :type path: string
143 143 :type filelist: list
144 144 """
145 145
146 146 def __init__(self, vernum, path, filelist):
147 147 self.version = VerNum(vernum)
148 148
149 149 # Collect scripts in this folder
150 150 self.sql = dict()
151 151 self.python = None
152 152
153 153 for script in filelist:
154 154 self.add_script(os.path.join(path, script))
155 155
156 156 def script(self, database=None, operation=None):
157 157 """Returns SQL or Python Script"""
158 158 for db in (database, 'default'):
159 159 # Try to return a .sql script first
160 160 try:
161 161 return self.sql[db][operation]
162 162 except KeyError:
163 163 continue # No .sql script exists
164 164
165 165 # TODO: maybe add force Python parameter?
166 166 ret = self.python
167 167
168 168 assert ret is not None, \
169 169 "There is no script for %d version" % self.version
170 170 return ret
171 171
172 172 def add_script(self, path):
173 173 """Add script to Collection/Version"""
174 174 if path.endswith(Extensions.py):
175 175 self._add_script_py(path)
176 176 elif path.endswith(Extensions.sql):
177 177 self._add_script_sql(path)
178 178
179 179 SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
180 180
181 181 def _add_script_sql(self, path):
182 182 basename = os.path.basename(path)
183 183 match = self.SQL_FILENAME.match(basename)
184 184
185 185 if match:
186 186 version, dbms, op = match.group(1), match.group(2), match.group(3)
187 187 else:
188 188 raise exceptions.ScriptError(
189 189 "Invalid SQL script name %s " % basename + \
190 190 "(needs to be ###_database_operation.sql)")
191 191
192 192 # File the script into a dictionary
193 193 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
194 194
195 195 def _add_script_py(self, path):
196 196 if self.python is not None:
197 197 raise exceptions.ScriptError('You can only have one Python script '
198 198 'per version, but you have: %s and %s' % (self.python, path))
199 199 self.python = script.PythonScript(path)
200 200
201 201
202 202 class Extensions:
203 203 """A namespace for file extensions"""
204 204 py = 'py'
205 205 sql = 'sql'
206 206
207 207 def str_to_filename(s):
208 208 """Replaces spaces, (double and single) quotes
209 209 and double underscores to underscores
210 210 """
211 211
212 212 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
213 213 while '__' in s:
214 214 s = s.replace('__', '_')
215 215 return s
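
A short sketch of browsing a versions/ directory with Collection; the path is an assumption and at least one NNN_*.py script is presumed to exist.

    from rhodecode.lib.dbmigrate.migrate.versioning.version import Collection

    coll = Collection('rhodecode/lib/dbmigrate/versions')
    print(coll.latest)              # highest NNN prefix found
    print(coll.version().python)    # PythonScript for the newest version, or None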