##// END OF EJS Templates
fixed imports on migrate, added getting current version from database
marcink -
r835:08d2dcd7 beta
parent child Browse files
Show More
@@ -0,0 +1,20 b''
1 [db_settings]
2 # Used to identify which repository this database is versioned under.
3 # You can use the name of your project.
4 repository_id=rhodecode_db_migrations
5
6 # The name of the database table used to track the schema version.
7 # This name shouldn't already be used by your project.
8 # If this is changed once a database is under version control, you'll need to
9 # change the table name in each database too.
10 version_table=db_migrate_version
11
12 # When committing a change script, Migrate will attempt to generate the
13 # sql for all supported databases; normally, if one of them fails - probably
14 # because you don't have that database installed - it is ignored and the
15 # commit continues, perhaps ending successfully.
16 # Databases in this list MUST compile successfully during a commit, or the
17 # entire commit will fail. List the databases your application will actually
18 # be using to ensure your updates to that database work properly.
19 # This must be a list; example: ['postgres','sqlite']
20 required_dbs=['sqlite']
@@ -1,328 +1,332 b''
1 #!/usr/bin/env python
1 # -*- coding: utf-8 -*-
2 # encoding: utf-8
2 """
3 # database management for RhodeCode
3 rhodecode.lib.db_manage
4 # Copyright (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
4 ~~~~~~~~~~~~~~~~~~~~~~~
5 #
5
6 Database creation, and setup module for RhodeCode
7
8 :created_on: Apr 10, 2010
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
12 """
6 # This program is free software; you can redistribute it and/or
13 # This program is free software; you can redistribute it and/or
7 # modify it under the terms of the GNU General Public License
14 # modify it under the terms of the GNU General Public License
8 # as published by the Free Software Foundation; version 2
15 # as published by the Free Software Foundation; version 2
9 # of the License or (at your opinion) any later version of the license.
16 # of the License or (at your opinion) any later version of the license.
10 #
17 #
11 # This program is distributed in the hope that it will be useful,
18 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
21 # GNU General Public License for more details.
15 #
22 #
16 # You should have received a copy of the GNU General Public License
23 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
24 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 # MA 02110-1301, USA.
26 # MA 02110-1301, USA.
20
27
21 """
22 Created on April 10, 2010
23 database management and creation for RhodeCode
24 @author: marcink
25 """
26
27 from os.path import dirname as dn, join as jn
28 import os
28 import os
29 import sys
29 import sys
30 import uuid
30 import uuid
31 import logging
32 from os.path import dirname as dn, join as jn
33
34 from rhodecode import __dbversion__
35 from rhodecode.model.db import
36 from rhodecode.model import meta
31
37
32 from rhodecode.lib.auth import get_crypt_password
38 from rhodecode.lib.auth import get_crypt_password
33 from rhodecode.lib.utils import ask_ok
39 from rhodecode.lib.utils import ask_ok
34 from rhodecode.model import init_model
40 from rhodecode.model import init_model
35 from rhodecode.model.db import User, Permission, RhodeCodeUi, RhodeCodeSettings, \
41 from rhodecode.model.db import User, Permission, RhodeCodeUi, RhodeCodeSettings, \
36 UserToPerm
42 UserToPerm, DbMigrateVersion
37 from rhodecode.model import meta
43
38 from sqlalchemy.engine import create_engine
44 from sqlalchemy.engine import create_engine
39 import logging
45
40
46
41 log = logging.getLogger(__name__)
47 log = logging.getLogger(__name__)
42
48
43 class DbManage(object):
49 class DbManage(object):
44 def __init__(self, log_sql, dbconf, root, tests=False):
50 def __init__(self, log_sql, dbconf, root, tests=False):
45 self.dbname = dbconf.split('/')[-1]
51 self.dbname = dbconf.split('/')[-1]
46 self.tests = tests
52 self.tests = tests
47 self.root = root
53 self.root = root
48 self.dburi = dbconf
54 self.dburi = dbconf
49 engine = create_engine(self.dburi, echo=log_sql)
55 engine = create_engine(self.dburi, echo=log_sql)
50 init_model(engine)
56 init_model(engine)
51 self.sa = meta.Session()
57 self.sa = meta.Session()
52 self.db_exists = False
58 self.db_exists = False
53
59
54 def check_for_db(self, override):
60 def check_for_db(self, override):
55 db_path = jn(self.root, self.dbname)
61 db_path = jn(self.root, self.dbname)
56 if self.dburi.startswith('sqlite'):
62 if self.dburi.startswith('sqlite'):
57 log.info('checking for existing db in %s', db_path)
63 log.info('checking for existing db in %s', db_path)
58 if os.path.isfile(db_path):
64 if os.path.isfile(db_path):
59
65
60 self.db_exists = True
66 self.db_exists = True
61 if not override:
67 if not override:
62 raise Exception('database already exists')
68 raise Exception('database already exists')
63
69
64 def create_tables(self, override=False):
70 def create_tables(self, override=False):
65 """
71 """
66 Create a auth database
72 Create a auth database
67 """
73 """
68 self.check_for_db(override)
74 self.check_for_db(override)
69 if self.db_exists:
75 if self.db_exists:
70 log.info("database exist and it's going to be destroyed")
76 log.info("database exist and it's going to be destroyed")
71 if self.tests:
77 if self.tests:
72 destroy = True
78 destroy = True
73 else:
79 else:
74 destroy = ask_ok('Are you sure to destroy old database ? [y/n]')
80 destroy = ask_ok('Are you sure to destroy old database ? [y/n]')
75 if not destroy:
81 if not destroy:
76 sys.exit()
82 sys.exit()
77 if self.db_exists and destroy:
83 if self.db_exists and destroy:
78 os.remove(jn(self.root, self.dbname))
84 os.remove(jn(self.root, self.dbname))
79 checkfirst = not override
85 checkfirst = not override
80 meta.Base.metadata.create_all(checkfirst=checkfirst)
86 meta.Base.metadata.create_all(checkfirst=checkfirst)
81 log.info('Created tables for %s', self.dbname)
87 log.info('Created tables for %s', self.dbname)
82
88
83
89
84
90
85 def set_db_version(self):
91 def set_db_version(self):
86 from rhodecode import __dbversion__
87 from rhodecode.model.db import DbMigrateVersion
88 try:
92 try:
89 ver = DbMigrateVersion()
93 ver = DbMigrateVersion()
90 ver.version = __dbversion__
94 ver.version = __dbversion__
91 ver.repository_id = 'rhodecode_db_migrations'
95 ver.repository_id = 'rhodecode_db_migrations'
92 ver.repository_path = 'versions'
96 ver.repository_path = 'versions'
93 self.sa.add(ver)
97 self.sa.add(ver)
94 self.sa.commit()
98 self.sa.commit()
95 except:
99 except:
96 self.sa.rollback()
100 self.sa.rollback()
97 raise
101 raise
98 log.info('db version set to: %s', __dbversion__)
102 log.info('db version set to: %s', __dbversion__)
99
103
100 def admin_prompt(self, second=False):
104 def admin_prompt(self, second=False):
101 if not self.tests:
105 if not self.tests:
102 import getpass
106 import getpass
103
107
104
108
105 def get_password():
109 def get_password():
106 password = getpass.getpass('Specify admin password (min 6 chars):')
110 password = getpass.getpass('Specify admin password (min 6 chars):')
107 confirm = getpass.getpass('Confirm password:')
111 confirm = getpass.getpass('Confirm password:')
108
112
109 if password != confirm:
113 if password != confirm:
110 log.error('passwords mismatch')
114 log.error('passwords mismatch')
111 return False
115 return False
112 if len(password) < 6:
116 if len(password) < 6:
113 log.error('password is to short use at least 6 characters')
117 log.error('password is to short use at least 6 characters')
114 return False
118 return False
115
119
116 return password
120 return password
117
121
118 username = raw_input('Specify admin username:')
122 username = raw_input('Specify admin username:')
119
123
120 password = get_password()
124 password = get_password()
121 if not password:
125 if not password:
122 #second try
126 #second try
123 password = get_password()
127 password = get_password()
124 if not password:
128 if not password:
125 sys.exit()
129 sys.exit()
126
130
127 email = raw_input('Specify admin email:')
131 email = raw_input('Specify admin email:')
128 self.create_user(username, password, email, True)
132 self.create_user(username, password, email, True)
129 else:
133 else:
130 log.info('creating admin and regular test users')
134 log.info('creating admin and regular test users')
131 self.create_user('test_admin', 'test12', 'test_admin@mail.com', True)
135 self.create_user('test_admin', 'test12', 'test_admin@mail.com', True)
132 self.create_user('test_regular', 'test12', 'test_regular@mail.com', False)
136 self.create_user('test_regular', 'test12', 'test_regular@mail.com', False)
133 self.create_user('test_regular2', 'test12', 'test_regular2@mail.com', False)
137 self.create_user('test_regular2', 'test12', 'test_regular2@mail.com', False)
134
138
135
139
136
140
137 def config_prompt(self, test_repo_path=''):
141 def config_prompt(self, test_repo_path=''):
138 log.info('Setting up repositories config')
142 log.info('Setting up repositories config')
139
143
140 if not self.tests and not test_repo_path:
144 if not self.tests and not test_repo_path:
141 path = raw_input('Specify valid full path to your repositories'
145 path = raw_input('Specify valid full path to your repositories'
142 ' you can change this later in application settings:')
146 ' you can change this later in application settings:')
143 else:
147 else:
144 path = test_repo_path
148 path = test_repo_path
145
149
146 if not os.path.isdir(path):
150 if not os.path.isdir(path):
147 log.error('You entered wrong path: %s', path)
151 log.error('You entered wrong path: %s', path)
148 sys.exit()
152 sys.exit()
149
153
150 hooks1 = RhodeCodeUi()
154 hooks1 = RhodeCodeUi()
151 hooks1.ui_section = 'hooks'
155 hooks1.ui_section = 'hooks'
152 hooks1.ui_key = 'changegroup.update'
156 hooks1.ui_key = 'changegroup.update'
153 hooks1.ui_value = 'hg update >&2'
157 hooks1.ui_value = 'hg update >&2'
154 hooks1.ui_active = False
158 hooks1.ui_active = False
155
159
156 hooks2 = RhodeCodeUi()
160 hooks2 = RhodeCodeUi()
157 hooks2.ui_section = 'hooks'
161 hooks2.ui_section = 'hooks'
158 hooks2.ui_key = 'changegroup.repo_size'
162 hooks2.ui_key = 'changegroup.repo_size'
159 hooks2.ui_value = 'python:rhodecode.lib.hooks.repo_size'
163 hooks2.ui_value = 'python:rhodecode.lib.hooks.repo_size'
160
164
161 hooks3 = RhodeCodeUi()
165 hooks3 = RhodeCodeUi()
162 hooks3.ui_section = 'hooks'
166 hooks3.ui_section = 'hooks'
163 hooks3.ui_key = 'pretxnchangegroup.push_logger'
167 hooks3.ui_key = 'pretxnchangegroup.push_logger'
164 hooks3.ui_value = 'python:rhodecode.lib.hooks.log_push_action'
168 hooks3.ui_value = 'python:rhodecode.lib.hooks.log_push_action'
165
169
166 hooks4 = RhodeCodeUi()
170 hooks4 = RhodeCodeUi()
167 hooks4.ui_section = 'hooks'
171 hooks4.ui_section = 'hooks'
168 hooks4.ui_key = 'preoutgoing.pull_logger'
172 hooks4.ui_key = 'preoutgoing.pull_logger'
169 hooks4.ui_value = 'python:rhodecode.lib.hooks.log_pull_action'
173 hooks4.ui_value = 'python:rhodecode.lib.hooks.log_pull_action'
170
174
171 #for mercurial 1.7 set backward comapatibility with format
175 #for mercurial 1.7 set backward comapatibility with format
172
176
173 dotencode_disable = RhodeCodeUi()
177 dotencode_disable = RhodeCodeUi()
174 dotencode_disable.ui_section = 'format'
178 dotencode_disable.ui_section = 'format'
175 dotencode_disable.ui_key = 'dotencode'
179 dotencode_disable.ui_key = 'dotencode'
176 dotencode_disable.ui_section = 'false'
180 dotencode_disable.ui_section = 'false'
177
181
178
182
179 web1 = RhodeCodeUi()
183 web1 = RhodeCodeUi()
180 web1.ui_section = 'web'
184 web1.ui_section = 'web'
181 web1.ui_key = 'push_ssl'
185 web1.ui_key = 'push_ssl'
182 web1.ui_value = 'false'
186 web1.ui_value = 'false'
183
187
184 web2 = RhodeCodeUi()
188 web2 = RhodeCodeUi()
185 web2.ui_section = 'web'
189 web2.ui_section = 'web'
186 web2.ui_key = 'allow_archive'
190 web2.ui_key = 'allow_archive'
187 web2.ui_value = 'gz zip bz2'
191 web2.ui_value = 'gz zip bz2'
188
192
189 web3 = RhodeCodeUi()
193 web3 = RhodeCodeUi()
190 web3.ui_section = 'web'
194 web3.ui_section = 'web'
191 web3.ui_key = 'allow_push'
195 web3.ui_key = 'allow_push'
192 web3.ui_value = '*'
196 web3.ui_value = '*'
193
197
194 web4 = RhodeCodeUi()
198 web4 = RhodeCodeUi()
195 web4.ui_section = 'web'
199 web4.ui_section = 'web'
196 web4.ui_key = 'baseurl'
200 web4.ui_key = 'baseurl'
197 web4.ui_value = '/'
201 web4.ui_value = '/'
198
202
199 paths = RhodeCodeUi()
203 paths = RhodeCodeUi()
200 paths.ui_section = 'paths'
204 paths.ui_section = 'paths'
201 paths.ui_key = '/'
205 paths.ui_key = '/'
202 paths.ui_value = path
206 paths.ui_value = path
203
207
204
208
205 hgsettings1 = RhodeCodeSettings('realm', 'RhodeCode authentication')
209 hgsettings1 = RhodeCodeSettings('realm', 'RhodeCode authentication')
206 hgsettings2 = RhodeCodeSettings('title', 'RhodeCode')
210 hgsettings2 = RhodeCodeSettings('title', 'RhodeCode')
207
211
208
212
209 try:
213 try:
210 self.sa.add(hooks1)
214 self.sa.add(hooks1)
211 self.sa.add(hooks2)
215 self.sa.add(hooks2)
212 self.sa.add(hooks3)
216 self.sa.add(hooks3)
213 self.sa.add(hooks4)
217 self.sa.add(hooks4)
214 self.sa.add(web1)
218 self.sa.add(web1)
215 self.sa.add(web2)
219 self.sa.add(web2)
216 self.sa.add(web3)
220 self.sa.add(web3)
217 self.sa.add(web4)
221 self.sa.add(web4)
218 self.sa.add(paths)
222 self.sa.add(paths)
219 self.sa.add(hgsettings1)
223 self.sa.add(hgsettings1)
220 self.sa.add(hgsettings2)
224 self.sa.add(hgsettings2)
221 self.sa.add(dotencode_disable)
225 self.sa.add(dotencode_disable)
222 for k in ['ldap_active', 'ldap_host', 'ldap_port', 'ldap_ldaps',
226 for k in ['ldap_active', 'ldap_host', 'ldap_port', 'ldap_ldaps',
223 'ldap_dn_user', 'ldap_dn_pass', 'ldap_base_dn']:
227 'ldap_dn_user', 'ldap_dn_pass', 'ldap_base_dn']:
224
228
225 setting = RhodeCodeSettings(k, '')
229 setting = RhodeCodeSettings(k, '')
226 self.sa.add(setting)
230 self.sa.add(setting)
227
231
228 self.sa.commit()
232 self.sa.commit()
229 except:
233 except:
230 self.sa.rollback()
234 self.sa.rollback()
231 raise
235 raise
232 log.info('created ui config')
236 log.info('created ui config')
233
237
234 def create_user(self, username, password, email='', admin=False):
238 def create_user(self, username, password, email='', admin=False):
235 log.info('creating administrator user %s', username)
239 log.info('creating administrator user %s', username)
236 new_user = User()
240 new_user = User()
237 new_user.username = username
241 new_user.username = username
238 new_user.password = get_crypt_password(password)
242 new_user.password = get_crypt_password(password)
239 new_user.name = 'RhodeCode'
243 new_user.name = 'RhodeCode'
240 new_user.lastname = 'Admin'
244 new_user.lastname = 'Admin'
241 new_user.email = email
245 new_user.email = email
242 new_user.admin = admin
246 new_user.admin = admin
243 new_user.active = True
247 new_user.active = True
244
248
245 try:
249 try:
246 self.sa.add(new_user)
250 self.sa.add(new_user)
247 self.sa.commit()
251 self.sa.commit()
248 except:
252 except:
249 self.sa.rollback()
253 self.sa.rollback()
250 raise
254 raise
251
255
252 def create_default_user(self):
256 def create_default_user(self):
253 log.info('creating default user')
257 log.info('creating default user')
254 #create default user for handling default permissions.
258 #create default user for handling default permissions.
255 def_user = User()
259 def_user = User()
256 def_user.username = 'default'
260 def_user.username = 'default'
257 def_user.password = get_crypt_password(str(uuid.uuid1())[:8])
261 def_user.password = get_crypt_password(str(uuid.uuid1())[:8])
258 def_user.name = 'Anonymous'
262 def_user.name = 'Anonymous'
259 def_user.lastname = 'User'
263 def_user.lastname = 'User'
260 def_user.email = 'anonymous@rhodecode.org'
264 def_user.email = 'anonymous@rhodecode.org'
261 def_user.admin = False
265 def_user.admin = False
262 def_user.active = False
266 def_user.active = False
263 try:
267 try:
264 self.sa.add(def_user)
268 self.sa.add(def_user)
265 self.sa.commit()
269 self.sa.commit()
266 except:
270 except:
267 self.sa.rollback()
271 self.sa.rollback()
268 raise
272 raise
269
273
270 def create_permissions(self):
274 def create_permissions(self):
271 #module.(access|create|change|delete)_[name]
275 #module.(access|create|change|delete)_[name]
272 #module.(read|write|owner)
276 #module.(read|write|owner)
273 perms = [('repository.none', 'Repository no access'),
277 perms = [('repository.none', 'Repository no access'),
274 ('repository.read', 'Repository read access'),
278 ('repository.read', 'Repository read access'),
275 ('repository.write', 'Repository write access'),
279 ('repository.write', 'Repository write access'),
276 ('repository.admin', 'Repository admin access'),
280 ('repository.admin', 'Repository admin access'),
277 ('hg.admin', 'Hg Administrator'),
281 ('hg.admin', 'Hg Administrator'),
278 ('hg.create.repository', 'Repository create'),
282 ('hg.create.repository', 'Repository create'),
279 ('hg.create.none', 'Repository creation disabled'),
283 ('hg.create.none', 'Repository creation disabled'),
280 ('hg.register.none', 'Register disabled'),
284 ('hg.register.none', 'Register disabled'),
281 ('hg.register.manual_activate', 'Register new user with rhodecode without manual activation'),
285 ('hg.register.manual_activate', 'Register new user with rhodecode without manual activation'),
282 ('hg.register.auto_activate', 'Register new user with rhodecode without auto activation'),
286 ('hg.register.auto_activate', 'Register new user with rhodecode without auto activation'),
283 ]
287 ]
284
288
285 for p in perms:
289 for p in perms:
286 new_perm = Permission()
290 new_perm = Permission()
287 new_perm.permission_name = p[0]
291 new_perm.permission_name = p[0]
288 new_perm.permission_longname = p[1]
292 new_perm.permission_longname = p[1]
289 try:
293 try:
290 self.sa.add(new_perm)
294 self.sa.add(new_perm)
291 self.sa.commit()
295 self.sa.commit()
292 except:
296 except:
293 self.sa.rollback()
297 self.sa.rollback()
294 raise
298 raise
295
299
296 def populate_default_permissions(self):
300 def populate_default_permissions(self):
297 log.info('creating default user permissions')
301 log.info('creating default user permissions')
298
302
299 default_user = self.sa.query(User)\
303 default_user = self.sa.query(User)\
300 .filter(User.username == 'default').scalar()
304 .filter(User.username == 'default').scalar()
301
305
302 reg_perm = UserToPerm()
306 reg_perm = UserToPerm()
303 reg_perm.user = default_user
307 reg_perm.user = default_user
304 reg_perm.permission = self.sa.query(Permission)\
308 reg_perm.permission = self.sa.query(Permission)\
305 .filter(Permission.permission_name == 'hg.register.manual_activate')\
309 .filter(Permission.permission_name == 'hg.register.manual_activate')\
306 .scalar()
310 .scalar()
307
311
308 create_repo_perm = UserToPerm()
312 create_repo_perm = UserToPerm()
309 create_repo_perm.user = default_user
313 create_repo_perm.user = default_user
310 create_repo_perm.permission = self.sa.query(Permission)\
314 create_repo_perm.permission = self.sa.query(Permission)\
311 .filter(Permission.permission_name == 'hg.create.repository')\
315 .filter(Permission.permission_name == 'hg.create.repository')\
312 .scalar()
316 .scalar()
313
317
314 default_repo_perm = UserToPerm()
318 default_repo_perm = UserToPerm()
315 default_repo_perm.user = default_user
319 default_repo_perm.user = default_user
316 default_repo_perm.permission = self.sa.query(Permission)\
320 default_repo_perm.permission = self.sa.query(Permission)\
317 .filter(Permission.permission_name == 'repository.read')\
321 .filter(Permission.permission_name == 'repository.read')\
318 .scalar()
322 .scalar()
319
323
320 try:
324 try:
321 self.sa.add(reg_perm)
325 self.sa.add(reg_perm)
322 self.sa.add(create_repo_perm)
326 self.sa.add(create_repo_perm)
323 self.sa.add(default_repo_perm)
327 self.sa.add(default_repo_perm)
324 self.sa.commit()
328 self.sa.commit()
325 except:
329 except:
326 self.sa.rollback()
330 self.sa.rollback()
327 raise
331 raise
328
332
@@ -1,59 +1,77 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 rhodecode.lib.dbmigrate.__init__
3 rhodecode.lib.dbmigrate.__init__
4 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5
5
6 Database migration modules
6 Database migration modules
7
7
8 :created_on: Dec 11, 2010
8 :created_on: Dec 11, 2010
9 :author: marcink
9 :author: marcink
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
10 :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com>
11 :license: GPLv3, see COPYING for more details.
11 :license: GPLv3, see COPYING for more details.
12 """
12 """
13 # This program is free software; you can redistribute it and/or
13 # This program is free software; you can redistribute it and/or
14 # modify it under the terms of the GNU General Public License
14 # modify it under the terms of the GNU General Public License
15 # as published by the Free Software Foundation; version 2
15 # as published by the Free Software Foundation; version 2
16 # of the License or (at your opinion) any later version of the license.
16 # of the License or (at your opinion) any later version of the license.
17 #
17 #
18 # This program is distributed in the hope that it will be useful,
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 # GNU General Public License for more details.
21 # GNU General Public License for more details.
22 #
22 #
23 # You should have received a copy of the GNU General Public License
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
25 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
26 # MA 02110-1301, USA.
26 # MA 02110-1301, USA.
27
27
28 from rhodecode.lib.utils import BasePasterCommand
28 import logging
29 from sqlalchemy import engine_from_config
30
31 from rhodecode.lib.dbmigrate.migrate.exceptions import \
32 DatabaseNotControlledError
29 from rhodecode.lib.utils import BasePasterCommand, Command, add_cache
33 from rhodecode.lib.utils import BasePasterCommand, Command, add_cache
30
34
31 from sqlalchemy import engine_from_config
35 log = logging.getLogger(__name__)
32
36
33 class UpgradeDb(BasePasterCommand):
37 class UpgradeDb(BasePasterCommand):
34 """Command used for paster to upgrade our database to newer version
38 """Command used for paster to upgrade our database to newer version
35 """
39 """
36
40
37 max_args = 1
41 max_args = 1
38 min_args = 1
42 min_args = 1
39
43
40 usage = "CONFIG_FILE"
44 usage = "CONFIG_FILE"
41 summary = "Upgrades current db to newer version given configuration file"
45 summary = "Upgrades current db to newer version given configuration file"
42 group_name = "RhodeCode"
46 group_name = "RhodeCode"
43
47
44 parser = Command.standard_parser(verbose=True)
48 parser = Command.standard_parser(verbose=True)
45
49
46 def command(self):
50 def command(self):
47 from pylons import config
51 from pylons import config
48 add_cache(config)
52 add_cache(config)
49 engine = engine_from_config(config, 'sqlalchemy.db1.')
53 #engine = engine_from_config(config, 'sqlalchemy.db1.')
50 print engine
54 #rint engine
51 raise NotImplementedError('Not implemented yet')
55
56 from rhodecode.lib.dbmigrate.migrate.versioning import api
57 path = 'rhodecode/lib/dbmigrate'
58
52
59
60 try:
61 curr_version = api.db_version(config['sqlalchemy.db1.url'], path)
62 msg = ('Found current database under version'
63 ' control with version %s' % curr_version)
64
65 except (RuntimeError, DatabaseNotControlledError), e:
66 curr_version = 0
67 msg = ('Current database is not under version control setting'
68 ' as version %s' % curr_version)
69
70 print msg
53
71
54 def update_parser(self):
72 def update_parser(self):
55 self.parser.add_option('--sql',
73 self.parser.add_option('--sql',
56 action='store_true',
74 action='store_true',
57 dest='just_sql',
75 dest='just_sql',
58 help="Prints upgrade sql for further investigation",
76 help="Prints upgrade sql for further investigation",
59 default=False)
77 default=False)
@@ -1,9 +1,9 b''
1 """
1 """
2 SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
2 SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
3 database schema version and repository management and
3 database schema version and repository management and
4 :mod:`migrate.changeset` that allows to define database schema changes
4 :mod:`migrate.changeset` that allows to define database schema changes
5 using Python.
5 using Python.
6 """
6 """
7
7
8 from migrate.versioning import *
8 from rhodecode.lib.dbmigrate.migrate.versioning import *
9 from migrate.changeset import *
9 from rhodecode.lib.dbmigrate.migrate.changeset import *
@@ -1,28 +1,28 b''
1 """
1 """
2 This module extends SQLAlchemy and provides additional DDL [#]_
2 This module extends SQLAlchemy and provides additional DDL [#]_
3 support.
3 support.
4
4
5 .. [#] SQL Data Definition Language
5 .. [#] SQL Data Definition Language
6 """
6 """
7 import re
7 import re
8 import warnings
8 import warnings
9
9
10 import sqlalchemy
10 import sqlalchemy
11 from sqlalchemy import __version__ as _sa_version
11 from sqlalchemy import __version__ as _sa_version
12
12
13 warnings.simplefilter('always', DeprecationWarning)
13 warnings.simplefilter('always', DeprecationWarning)
14
14
15 _sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
15 _sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
16 SQLA_06 = _sa_version >= (0, 6)
16 SQLA_06 = _sa_version >= (0, 6)
17
17
18 del re
18 del re
19 del _sa_version
19 del _sa_version
20
20
21 from migrate.changeset.schema import *
21 from rhodecode.lib.dbmigrate.migrate.changeset.schema import *
22 from migrate.changeset.constraint import *
22 from rhodecode.lib.dbmigrate.migrate.changeset.constraint import *
23
23
24 sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
24 sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
25 sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
25 sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
26 sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
26 sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
27
27
28 sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )
28 sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )
@@ -1,358 +1,358 b''
1 """
1 """
2 Extensions to SQLAlchemy for altering existing tables.
2 Extensions to SQLAlchemy for altering existing tables.
3
3
4 At the moment, this isn't so much based off of ANSI as much as
4 At the moment, this isn't so much based off of ANSI as much as
5 things that just happen to work with multiple databases.
5 things that just happen to work with multiple databases.
6 """
6 """
7 import StringIO
7 import StringIO
8
8
9 import sqlalchemy as sa
9 import sqlalchemy as sa
10 from sqlalchemy.schema import SchemaVisitor
10 from sqlalchemy.schema import SchemaVisitor
11 from sqlalchemy.engine.default import DefaultDialect
11 from sqlalchemy.engine.default import DefaultDialect
12 from sqlalchemy.sql import ClauseElement
12 from sqlalchemy.sql import ClauseElement
13 from sqlalchemy.schema import (ForeignKeyConstraint,
13 from sqlalchemy.schema import (ForeignKeyConstraint,
14 PrimaryKeyConstraint,
14 PrimaryKeyConstraint,
15 CheckConstraint,
15 CheckConstraint,
16 UniqueConstraint,
16 UniqueConstraint,
17 Index)
17 Index)
18
18
19 from migrate import exceptions
19 from rhodecode.lib.dbmigrate.migrate import exceptions
20 from migrate.changeset import constraint, SQLA_06
20 from rhodecode.lib.dbmigrate.migrate.changeset import constraint, SQLA_06
21
21
22 if not SQLA_06:
22 if not SQLA_06:
23 from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
23 from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
24 else:
24 else:
25 from sqlalchemy.schema import AddConstraint, DropConstraint
25 from sqlalchemy.schema import AddConstraint, DropConstraint
26 from sqlalchemy.sql.compiler import DDLCompiler
26 from sqlalchemy.sql.compiler import DDLCompiler
27 SchemaGenerator = SchemaDropper = DDLCompiler
27 SchemaGenerator = SchemaDropper = DDLCompiler
28
28
29
29
30 class AlterTableVisitor(SchemaVisitor):
30 class AlterTableVisitor(SchemaVisitor):
31 """Common operations for ``ALTER TABLE`` statements."""
31 """Common operations for ``ALTER TABLE`` statements."""
32
32
33 if SQLA_06:
33 if SQLA_06:
34 # engine.Compiler looks for .statement
34 # engine.Compiler looks for .statement
35 # when it spawns off a new compiler
35 # when it spawns off a new compiler
36 statement = ClauseElement()
36 statement = ClauseElement()
37
37
38 def append(self, s):
38 def append(self, s):
39 """Append content to the SchemaIterator's query buffer."""
39 """Append content to the SchemaIterator's query buffer."""
40
40
41 self.buffer.write(s)
41 self.buffer.write(s)
42
42
43 def execute(self):
43 def execute(self):
44 """Execute the contents of the SchemaIterator's buffer."""
44 """Execute the contents of the SchemaIterator's buffer."""
45 try:
45 try:
46 return self.connection.execute(self.buffer.getvalue())
46 return self.connection.execute(self.buffer.getvalue())
47 finally:
47 finally:
48 self.buffer.truncate(0)
48 self.buffer.truncate(0)
49
49
50 def __init__(self, dialect, connection, **kw):
50 def __init__(self, dialect, connection, **kw):
51 self.connection = connection
51 self.connection = connection
52 self.buffer = StringIO.StringIO()
52 self.buffer = StringIO.StringIO()
53 self.preparer = dialect.identifier_preparer
53 self.preparer = dialect.identifier_preparer
54 self.dialect = dialect
54 self.dialect = dialect
55
55
56 def traverse_single(self, elem):
56 def traverse_single(self, elem):
57 ret = super(AlterTableVisitor, self).traverse_single(elem)
57 ret = super(AlterTableVisitor, self).traverse_single(elem)
58 if ret:
58 if ret:
59 # adapt to 0.6 which uses a string-returning
59 # adapt to 0.6 which uses a string-returning
60 # object
60 # object
61 self.append(" %s" % ret)
61 self.append(" %s" % ret)
62
62
63 def _to_table(self, param):
63 def _to_table(self, param):
64 """Returns the table object for the given param object."""
64 """Returns the table object for the given param object."""
65 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
65 if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
66 ret = param.table
66 ret = param.table
67 else:
67 else:
68 ret = param
68 ret = param
69 return ret
69 return ret
70
70
71 def start_alter_table(self, param):
71 def start_alter_table(self, param):
72 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
72 """Returns the start of an ``ALTER TABLE`` SQL-Statement.
73
73
74 Use the param object to determine the table name and use it
74 Use the param object to determine the table name and use it
75 for building the SQL statement.
75 for building the SQL statement.
76
76
77 :param param: object to determine the table from
77 :param param: object to determine the table from
78 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
78 :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
79 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
79 :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
80 or string (table name)
80 or string (table name)
81 """
81 """
82 table = self._to_table(param)
82 table = self._to_table(param)
83 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
83 self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
84 return table
84 return table
85
85
86
86
87 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
87 class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
88 """Extends ansisql generator for column creation (alter table add col)"""
88 """Extends ansisql generator for column creation (alter table add col)"""
89
89
90 def visit_column(self, column):
90 def visit_column(self, column):
91 """Create a column (table already exists).
91 """Create a column (table already exists).
92
92
93 :param column: column object
93 :param column: column object
94 :type column: :class:`sqlalchemy.Column` instance
94 :type column: :class:`sqlalchemy.Column` instance
95 """
95 """
96 if column.default is not None:
96 if column.default is not None:
97 self.traverse_single(column.default)
97 self.traverse_single(column.default)
98
98
99 table = self.start_alter_table(column)
99 table = self.start_alter_table(column)
100 self.append("ADD ")
100 self.append("ADD ")
101 self.append(self.get_column_specification(column))
101 self.append(self.get_column_specification(column))
102
102
103 for cons in column.constraints:
103 for cons in column.constraints:
104 self.traverse_single(cons)
104 self.traverse_single(cons)
105 self.execute()
105 self.execute()
106
106
107 # ALTER TABLE STATEMENTS
107 # ALTER TABLE STATEMENTS
108
108
109 # add indexes and unique constraints
109 # add indexes and unique constraints
110 if column.index_name:
110 if column.index_name:
111 Index(column.index_name,column).create()
111 Index(column.index_name,column).create()
112 elif column.unique_name:
112 elif column.unique_name:
113 constraint.UniqueConstraint(column,
113 constraint.UniqueConstraint(column,
114 name=column.unique_name).create()
114 name=column.unique_name).create()
115
115
116 # SA bounds FK constraints to table, add manually
116 # SA bounds FK constraints to table, add manually
117 for fk in column.foreign_keys:
117 for fk in column.foreign_keys:
118 self.add_foreignkey(fk.constraint)
118 self.add_foreignkey(fk.constraint)
119
119
120 # add primary key constraint if needed
120 # add primary key constraint if needed
121 if column.primary_key_name:
121 if column.primary_key_name:
122 cons = constraint.PrimaryKeyConstraint(column,
122 cons = constraint.PrimaryKeyConstraint(column,
123 name=column.primary_key_name)
123 name=column.primary_key_name)
124 cons.create()
124 cons.create()
125
125
126 if SQLA_06:
126 if SQLA_06:
127 def add_foreignkey(self, fk):
127 def add_foreignkey(self, fk):
128 self.connection.execute(AddConstraint(fk))
128 self.connection.execute(AddConstraint(fk))
129
129
130 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
130 class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
131 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
131 """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
132 DROP COLUMN``).
132 DROP COLUMN``).
133 """
133 """
134
134
135 def visit_column(self, column):
135 def visit_column(self, column):
136 """Drop a column from its table.
136 """Drop a column from its table.
137
137
138 :param column: the column object
138 :param column: the column object
139 :type column: :class:`sqlalchemy.Column`
139 :type column: :class:`sqlalchemy.Column`
140 """
140 """
141 table = self.start_alter_table(column)
141 table = self.start_alter_table(column)
142 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
142 self.append('DROP COLUMN %s' % self.preparer.format_column(column))
143 self.execute()
143 self.execute()
144
144
145
145
146 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
146 class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
147 """Manages changes to existing schema elements.
147 """Manages changes to existing schema elements.
148
148
149 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
149 Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
150 is in SchemaGenerator.
150 is in SchemaGenerator.
151
151
152 All items may be renamed. Columns can also have many of their properties -
152 All items may be renamed. Columns can also have many of their properties -
153 type, for example - changed.
153 type, for example - changed.
154
154
155 Each function is passed a tuple, containing (object, name); where
155 Each function is passed a tuple, containing (object, name); where
156 object is a type of object you'd expect for that function
156 object is a type of object you'd expect for that function
157 (ie. table for visit_table) and name is the object's new
157 (ie. table for visit_table) and name is the object's new
158 name. NONE means the name is unchanged.
158 name. NONE means the name is unchanged.
159 """
159 """
160
160
161 def visit_table(self, table):
161 def visit_table(self, table):
162 """Rename a table. Other ops aren't supported."""
162 """Rename a table. Other ops aren't supported."""
163 self.start_alter_table(table)
163 self.start_alter_table(table)
164 self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
164 self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
165 table.quote))
165 table.quote))
166 self.execute()
166 self.execute()
167
167
168 def visit_index(self, index):
168 def visit_index(self, index):
169 """Rename an index"""
169 """Rename an index"""
170 if hasattr(self, '_validate_identifier'):
170 if hasattr(self, '_validate_identifier'):
171 # SA <= 0.6.3
171 # SA <= 0.6.3
172 self.append("ALTER INDEX %s RENAME TO %s" % (
172 self.append("ALTER INDEX %s RENAME TO %s" % (
173 self.preparer.quote(
173 self.preparer.quote(
174 self._validate_identifier(
174 self._validate_identifier(
175 index.name, True), index.quote),
175 index.name, True), index.quote),
176 self.preparer.quote(
176 self.preparer.quote(
177 self._validate_identifier(
177 self._validate_identifier(
178 index.new_name, True), index.quote)))
178 index.new_name, True), index.quote)))
179 else:
179 else:
180 # SA >= 0.6.5
180 # SA >= 0.6.5
181 self.append("ALTER INDEX %s RENAME TO %s" % (
181 self.append("ALTER INDEX %s RENAME TO %s" % (
182 self.preparer.quote(
182 self.preparer.quote(
183 self._index_identifier(
183 self._index_identifier(
184 index.name), index.quote),
184 index.name), index.quote),
185 self.preparer.quote(
185 self.preparer.quote(
186 self._index_identifier(
186 self._index_identifier(
187 index.new_name), index.quote)))
187 index.new_name), index.quote)))
188 self.execute()
188 self.execute()
189
189
190 def visit_column(self, delta):
190 def visit_column(self, delta):
191 """Rename/change a column."""
191 """Rename/change a column."""
192 # ALTER COLUMN is implemented as several ALTER statements
192 # ALTER COLUMN is implemented as several ALTER statements
193 keys = delta.keys()
193 keys = delta.keys()
194 if 'type' in keys:
194 if 'type' in keys:
195 self._run_subvisit(delta, self._visit_column_type)
195 self._run_subvisit(delta, self._visit_column_type)
196 if 'nullable' in keys:
196 if 'nullable' in keys:
197 self._run_subvisit(delta, self._visit_column_nullable)
197 self._run_subvisit(delta, self._visit_column_nullable)
198 if 'server_default' in keys:
198 if 'server_default' in keys:
199 # Skip 'default': only handle server-side defaults, others
199 # Skip 'default': only handle server-side defaults, others
200 # are managed by the app, not the db.
200 # are managed by the app, not the db.
201 self._run_subvisit(delta, self._visit_column_default)
201 self._run_subvisit(delta, self._visit_column_default)
202 if 'name' in keys:
202 if 'name' in keys:
203 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
203 self._run_subvisit(delta, self._visit_column_name, start_alter=False)
204
204
205 def _run_subvisit(self, delta, func, start_alter=True):
205 def _run_subvisit(self, delta, func, start_alter=True):
206 """Runs visit method based on what needs to be changed on column"""
206 """Runs visit method based on what needs to be changed on column"""
207 table = self._to_table(delta.table)
207 table = self._to_table(delta.table)
208 col_name = delta.current_name
208 col_name = delta.current_name
209 if start_alter:
209 if start_alter:
210 self.start_alter_column(table, col_name)
210 self.start_alter_column(table, col_name)
211 ret = func(table, delta.result_column, delta)
211 ret = func(table, delta.result_column, delta)
212 self.execute()
212 self.execute()
213
213
214 def start_alter_column(self, table, col_name):
214 def start_alter_column(self, table, col_name):
215 """Starts ALTER COLUMN"""
215 """Starts ALTER COLUMN"""
216 self.start_alter_table(table)
216 self.start_alter_table(table)
217 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
217 self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
218
218
219 def _visit_column_nullable(self, table, column, delta):
219 def _visit_column_nullable(self, table, column, delta):
220 nullable = delta['nullable']
220 nullable = delta['nullable']
221 if nullable:
221 if nullable:
222 self.append("DROP NOT NULL")
222 self.append("DROP NOT NULL")
223 else:
223 else:
224 self.append("SET NOT NULL")
224 self.append("SET NOT NULL")
225
225
226 def _visit_column_default(self, table, column, delta):
226 def _visit_column_default(self, table, column, delta):
227 default_text = self.get_column_default_string(column)
227 default_text = self.get_column_default_string(column)
228 if default_text is not None:
228 if default_text is not None:
229 self.append("SET DEFAULT %s" % default_text)
229 self.append("SET DEFAULT %s" % default_text)
230 else:
230 else:
231 self.append("DROP DEFAULT")
231 self.append("DROP DEFAULT")
232
232
233 def _visit_column_type(self, table, column, delta):
233 def _visit_column_type(self, table, column, delta):
234 type_ = delta['type']
234 type_ = delta['type']
235 if SQLA_06:
235 if SQLA_06:
236 type_text = str(type_.compile(dialect=self.dialect))
236 type_text = str(type_.compile(dialect=self.dialect))
237 else:
237 else:
238 type_text = type_.dialect_impl(self.dialect).get_col_spec()
238 type_text = type_.dialect_impl(self.dialect).get_col_spec()
239 self.append("TYPE %s" % type_text)
239 self.append("TYPE %s" % type_text)
240
240
241 def _visit_column_name(self, table, column, delta):
241 def _visit_column_name(self, table, column, delta):
242 self.start_alter_table(table)
242 self.start_alter_table(table)
243 col_name = self.preparer.quote(delta.current_name, table.quote)
243 col_name = self.preparer.quote(delta.current_name, table.quote)
244 new_name = self.preparer.format_column(delta.result_column)
244 new_name = self.preparer.format_column(delta.result_column)
245 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
245 self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
246
246
247
247
248 class ANSIConstraintCommon(AlterTableVisitor):
248 class ANSIConstraintCommon(AlterTableVisitor):
249 """
249 """
250 Migrate's constraints require a separate creation function from
250 Migrate's constraints require a separate creation function from
251 SA's: Migrate's constraints are created independently of a table;
251 SA's: Migrate's constraints are created independently of a table;
252 SA's are created at the same time as the table.
252 SA's are created at the same time as the table.
253 """
253 """
254
254
255 def get_constraint_name(self, cons):
255 def get_constraint_name(self, cons):
256 """Gets a name for the given constraint.
256 """Gets a name for the given constraint.
257
257
258 If the name is already set it will be used otherwise the
258 If the name is already set it will be used otherwise the
259 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
259 constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
260 method is used.
260 method is used.
261
261
262 :param cons: constraint object
262 :param cons: constraint object
263 """
263 """
264 if cons.name is not None:
264 if cons.name is not None:
265 ret = cons.name
265 ret = cons.name
266 else:
266 else:
267 ret = cons.name = cons.autoname()
267 ret = cons.name = cons.autoname()
268 return self.preparer.quote(ret, cons.quote)
268 return self.preparer.quote(ret, cons.quote)
269
269
270 def visit_migrate_primary_key_constraint(self, *p, **k):
270 def visit_migrate_primary_key_constraint(self, *p, **k):
271 self._visit_constraint(*p, **k)
271 self._visit_constraint(*p, **k)
272
272
273 def visit_migrate_foreign_key_constraint(self, *p, **k):
273 def visit_migrate_foreign_key_constraint(self, *p, **k):
274 self._visit_constraint(*p, **k)
274 self._visit_constraint(*p, **k)
275
275
276 def visit_migrate_check_constraint(self, *p, **k):
276 def visit_migrate_check_constraint(self, *p, **k):
277 self._visit_constraint(*p, **k)
277 self._visit_constraint(*p, **k)
278
278
279 def visit_migrate_unique_constraint(self, *p, **k):
279 def visit_migrate_unique_constraint(self, *p, **k):
280 self._visit_constraint(*p, **k)
280 self._visit_constraint(*p, **k)
281
281
282 if SQLA_06:
282 if SQLA_06:
283 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
283 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
284 def _visit_constraint(self, constraint):
284 def _visit_constraint(self, constraint):
285 constraint.name = self.get_constraint_name(constraint)
285 constraint.name = self.get_constraint_name(constraint)
286 self.append(self.process(AddConstraint(constraint)))
286 self.append(self.process(AddConstraint(constraint)))
287 self.execute()
287 self.execute()
288
288
289 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
289 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
290 def _visit_constraint(self, constraint):
290 def _visit_constraint(self, constraint):
291 constraint.name = self.get_constraint_name(constraint)
291 constraint.name = self.get_constraint_name(constraint)
292 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
292 self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
293 self.execute()
293 self.execute()
294
294
295 else:
295 else:
296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
296 class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
297
297
298 def get_constraint_specification(self, cons, **kwargs):
298 def get_constraint_specification(self, cons, **kwargs):
299 """Constaint SQL generators.
299 """Constaint SQL generators.
300
300
301 We cannot use SA visitors because they append comma.
301 We cannot use SA visitors because they append comma.
302 """
302 """
303
303
304 if isinstance(cons, PrimaryKeyConstraint):
304 if isinstance(cons, PrimaryKeyConstraint):
305 if cons.name is not None:
305 if cons.name is not None:
306 self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
306 self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
307 self.append("PRIMARY KEY ")
307 self.append("PRIMARY KEY ")
308 self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
308 self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
309 for c in cons))
309 for c in cons))
310 self.define_constraint_deferrability(cons)
310 self.define_constraint_deferrability(cons)
311 elif isinstance(cons, ForeignKeyConstraint):
311 elif isinstance(cons, ForeignKeyConstraint):
312 self.define_foreign_key(cons)
312 self.define_foreign_key(cons)
313 elif isinstance(cons, CheckConstraint):
313 elif isinstance(cons, CheckConstraint):
314 if cons.name is not None:
314 if cons.name is not None:
315 self.append("CONSTRAINT %s " %
315 self.append("CONSTRAINT %s " %
316 self.preparer.format_constraint(cons))
316 self.preparer.format_constraint(cons))
317 self.append("CHECK (%s)" % cons.sqltext)
317 self.append("CHECK (%s)" % cons.sqltext)
318 self.define_constraint_deferrability(cons)
318 self.define_constraint_deferrability(cons)
319 elif isinstance(cons, UniqueConstraint):
319 elif isinstance(cons, UniqueConstraint):
320 if cons.name is not None:
320 if cons.name is not None:
321 self.append("CONSTRAINT %s " %
321 self.append("CONSTRAINT %s " %
322 self.preparer.format_constraint(cons))
322 self.preparer.format_constraint(cons))
323 self.append("UNIQUE (%s)" % \
323 self.append("UNIQUE (%s)" % \
324 (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
324 (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
325 self.define_constraint_deferrability(cons)
325 self.define_constraint_deferrability(cons)
326 else:
326 else:
327 raise exceptions.InvalidConstraintError(cons)
327 raise exceptions.InvalidConstraintError(cons)
328
328
329 def _visit_constraint(self, constraint):
329 def _visit_constraint(self, constraint):
330
330
331 table = self.start_alter_table(constraint)
331 table = self.start_alter_table(constraint)
332 constraint.name = self.get_constraint_name(constraint)
332 constraint.name = self.get_constraint_name(constraint)
333 self.append("ADD ")
333 self.append("ADD ")
334 self.get_constraint_specification(constraint)
334 self.get_constraint_specification(constraint)
335 self.execute()
335 self.execute()
336
336
337
337
338 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
338 class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
339
339
340 def _visit_constraint(self, constraint):
340 def _visit_constraint(self, constraint):
341 self.start_alter_table(constraint)
341 self.start_alter_table(constraint)
342 self.append("DROP CONSTRAINT ")
342 self.append("DROP CONSTRAINT ")
343 constraint.name = self.get_constraint_name(constraint)
343 constraint.name = self.get_constraint_name(constraint)
344 self.append(self.preparer.format_constraint(constraint))
344 self.append(self.preparer.format_constraint(constraint))
345 if constraint.cascade:
345 if constraint.cascade:
346 self.cascade_constraint(constraint)
346 self.cascade_constraint(constraint)
347 self.execute()
347 self.execute()
348
348
349 def cascade_constraint(self, constraint):
349 def cascade_constraint(self, constraint):
350 self.append(" CASCADE")
350 self.append(" CASCADE")
351
351
352
352
353 class ANSIDialect(DefaultDialect):
353 class ANSIDialect(DefaultDialect):
354 columngenerator = ANSIColumnGenerator
354 columngenerator = ANSIColumnGenerator
355 columndropper = ANSIColumnDropper
355 columndropper = ANSIColumnDropper
356 schemachanger = ANSISchemaChanger
356 schemachanger = ANSISchemaChanger
357 constraintgenerator = ANSIConstraintGenerator
357 constraintgenerator = ANSIConstraintGenerator
358 constraintdropper = ANSIConstraintDropper
358 constraintdropper = ANSIConstraintDropper
@@ -1,202 +1,202 b''
1 """
1 """
2 This module defines standalone schema constraint classes.
2 This module defines standalone schema constraint classes.
3 """
3 """
4 from sqlalchemy import schema
4 from sqlalchemy import schema
5
5
6 from migrate.exceptions import *
6 from rhodecode.lib.dbmigrate.migrate.exceptions import *
7 from migrate.changeset import SQLA_06
7 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
8
8
9 class ConstraintChangeset(object):
9 class ConstraintChangeset(object):
10 """Base class for Constraint classes."""
10 """Base class for Constraint classes."""
11
11
12 def _normalize_columns(self, cols, table_name=False):
12 def _normalize_columns(self, cols, table_name=False):
13 """Given: column objects or names; return col names and
13 """Given: column objects or names; return col names and
14 (maybe) a table"""
14 (maybe) a table"""
15 colnames = []
15 colnames = []
16 table = None
16 table = None
17 for col in cols:
17 for col in cols:
18 if isinstance(col, schema.Column):
18 if isinstance(col, schema.Column):
19 if col.table is not None and table is None:
19 if col.table is not None and table is None:
20 table = col.table
20 table = col.table
21 if table_name:
21 if table_name:
22 col = '.'.join((col.table.name, col.name))
22 col = '.'.join((col.table.name, col.name))
23 else:
23 else:
24 col = col.name
24 col = col.name
25 colnames.append(col)
25 colnames.append(col)
26 return colnames, table
26 return colnames, table
27
27
28 def __do_imports(self, visitor_name, *a, **kw):
28 def __do_imports(self, visitor_name, *a, **kw):
29 engine = kw.pop('engine', self.table.bind)
29 engine = kw.pop('engine', self.table.bind)
30 from migrate.changeset.databases.visitor import (get_engine_visitor,
30 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
31 run_single_visitor)
31 run_single_visitor)
32 visitorcallable = get_engine_visitor(engine, visitor_name)
32 visitorcallable = get_engine_visitor(engine, visitor_name)
33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
33 run_single_visitor(engine, visitorcallable, self, *a, **kw)
34
34
35 def create(self, *a, **kw):
35 def create(self, *a, **kw):
36 """Create the constraint in the database.
36 """Create the constraint in the database.
37
37
38 :param engine: the database engine to use. If this is \
38 :param engine: the database engine to use. If this is \
39 :keyword:`None` the instance's engine will be used
39 :keyword:`None` the instance's engine will be used
40 :type engine: :class:`sqlalchemy.engine.base.Engine`
40 :type engine: :class:`sqlalchemy.engine.base.Engine`
41 :param connection: reuse connection istead of creating new one.
41 :param connection: reuse connection istead of creating new one.
42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
42 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
43 """
43 """
44 # TODO: set the parent here instead of in __init__
44 # TODO: set the parent here instead of in __init__
45 self.__do_imports('constraintgenerator', *a, **kw)
45 self.__do_imports('constraintgenerator', *a, **kw)
46
46
47 def drop(self, *a, **kw):
47 def drop(self, *a, **kw):
48 """Drop the constraint from the database.
48 """Drop the constraint from the database.
49
49
50 :param engine: the database engine to use. If this is
50 :param engine: the database engine to use. If this is
51 :keyword:`None` the instance's engine will be used
51 :keyword:`None` the instance's engine will be used
52 :param cascade: Issue CASCADE drop if database supports it
52 :param cascade: Issue CASCADE drop if database supports it
53 :type engine: :class:`sqlalchemy.engine.base.Engine`
53 :type engine: :class:`sqlalchemy.engine.base.Engine`
54 :type cascade: bool
54 :type cascade: bool
55 :param connection: reuse connection istead of creating new one.
55 :param connection: reuse connection istead of creating new one.
56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
56 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
57 :returns: Instance with cleared columns
57 :returns: Instance with cleared columns
58 """
58 """
59 self.cascade = kw.pop('cascade', False)
59 self.cascade = kw.pop('cascade', False)
60 self.__do_imports('constraintdropper', *a, **kw)
60 self.__do_imports('constraintdropper', *a, **kw)
61 # the spirit of Constraint objects is that they
61 # the spirit of Constraint objects is that they
62 # are immutable (just like in a DB. they're only ADDed
62 # are immutable (just like in a DB. they're only ADDed
63 # or DROPped).
63 # or DROPped).
64 #self.columns.clear()
64 #self.columns.clear()
65 return self
65 return self
66
66
67
67
68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
68 class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
69 """Construct PrimaryKeyConstraint
69 """Construct PrimaryKeyConstraint
70
70
71 Migrate's additional parameters:
71 Migrate's additional parameters:
72
72
73 :param cols: Columns in constraint.
73 :param cols: Columns in constraint.
74 :param table: If columns are passed as strings, this kw is required
74 :param table: If columns are passed as strings, this kw is required
75 :type table: Table instance
75 :type table: Table instance
76 :type cols: strings or Column instances
76 :type cols: strings or Column instances
77 """
77 """
78
78
79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
79 __migrate_visit_name__ = 'migrate_primary_key_constraint'
80
80
81 def __init__(self, *cols, **kwargs):
81 def __init__(self, *cols, **kwargs):
82 colnames, table = self._normalize_columns(cols)
82 colnames, table = self._normalize_columns(cols)
83 table = kwargs.pop('table', table)
83 table = kwargs.pop('table', table)
84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
84 super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
85 if table is not None:
85 if table is not None:
86 self._set_parent(table)
86 self._set_parent(table)
87
87
88
88
89 def autoname(self):
89 def autoname(self):
90 """Mimic the database's automatic constraint names"""
90 """Mimic the database's automatic constraint names"""
91 return "%s_pkey" % self.table.name
91 return "%s_pkey" % self.table.name
92
92
93
93
94 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
94 class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
95 """Construct ForeignKeyConstraint
95 """Construct ForeignKeyConstraint
96
96
97 Migrate's additional parameters:
97 Migrate's additional parameters:
98
98
99 :param columns: Columns in constraint
99 :param columns: Columns in constraint
100 :param refcolumns: Columns that this FK reffers to in another table.
100 :param refcolumns: Columns that this FK reffers to in another table.
101 :param table: If columns are passed as strings, this kw is required
101 :param table: If columns are passed as strings, this kw is required
102 :type table: Table instance
102 :type table: Table instance
103 :type columns: list of strings or Column instances
103 :type columns: list of strings or Column instances
104 :type refcolumns: list of strings or Column instances
104 :type refcolumns: list of strings or Column instances
105 """
105 """
106
106
107 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
107 __migrate_visit_name__ = 'migrate_foreign_key_constraint'
108
108
109 def __init__(self, columns, refcolumns, *args, **kwargs):
109 def __init__(self, columns, refcolumns, *args, **kwargs):
110 colnames, table = self._normalize_columns(columns)
110 colnames, table = self._normalize_columns(columns)
111 table = kwargs.pop('table', table)
111 table = kwargs.pop('table', table)
112 refcolnames, reftable = self._normalize_columns(refcolumns,
112 refcolnames, reftable = self._normalize_columns(refcolumns,
113 table_name=True)
113 table_name=True)
114 super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
114 super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
115 **kwargs)
115 **kwargs)
116 if table is not None:
116 if table is not None:
117 self._set_parent(table)
117 self._set_parent(table)
118
118
119 @property
119 @property
120 def referenced(self):
120 def referenced(self):
121 return [e.column for e in self.elements]
121 return [e.column for e in self.elements]
122
122
123 @property
123 @property
124 def reftable(self):
124 def reftable(self):
125 return self.referenced[0].table
125 return self.referenced[0].table
126
126
127 def autoname(self):
127 def autoname(self):
128 """Mimic the database's automatic constraint names"""
128 """Mimic the database's automatic constraint names"""
129 if hasattr(self.columns, 'keys'):
129 if hasattr(self.columns, 'keys'):
130 # SA <= 0.5
130 # SA <= 0.5
131 firstcol = self.columns[self.columns.keys()[0]]
131 firstcol = self.columns[self.columns.keys()[0]]
132 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
132 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
133 table=firstcol.table.name,
133 table=firstcol.table.name,
134 firstcolumn=firstcol.name,)
134 firstcolumn=firstcol.name,)
135 else:
135 else:
136 # SA >= 0.6
136 # SA >= 0.6
137 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
137 ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
138 table=self.table.name,
138 table=self.table.name,
139 firstcolumn=self.columns[0],)
139 firstcolumn=self.columns[0],)
140 return ret
140 return ret
141
141
142
142
143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
143 class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
144 """Construct CheckConstraint
144 """Construct CheckConstraint
145
145
146 Migrate's additional parameters:
146 Migrate's additional parameters:
147
147
148 :param sqltext: Plain SQL text to check condition
148 :param sqltext: Plain SQL text to check condition
149 :param columns: If not name is applied, you must supply this kw\
149 :param columns: If not name is applied, you must supply this kw\
150 to autoname constraint
150 to autoname constraint
151 :param table: If columns are passed as strings, this kw is required
151 :param table: If columns are passed as strings, this kw is required
152 :type table: Table instance
152 :type table: Table instance
153 :type columns: list of Columns instances
153 :type columns: list of Columns instances
154 :type sqltext: string
154 :type sqltext: string
155 """
155 """
156
156
157 __migrate_visit_name__ = 'migrate_check_constraint'
157 __migrate_visit_name__ = 'migrate_check_constraint'
158
158
159 def __init__(self, sqltext, *args, **kwargs):
159 def __init__(self, sqltext, *args, **kwargs):
160 cols = kwargs.pop('columns', [])
160 cols = kwargs.pop('columns', [])
161 if not cols and not kwargs.get('name', False):
161 if not cols and not kwargs.get('name', False):
162 raise InvalidConstraintError('You must either set "name"'
162 raise InvalidConstraintError('You must either set "name"'
163 'parameter or "columns" to autogenarate it.')
163 'parameter or "columns" to autogenarate it.')
164 colnames, table = self._normalize_columns(cols)
164 colnames, table = self._normalize_columns(cols)
165 table = kwargs.pop('table', table)
165 table = kwargs.pop('table', table)
166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
166 schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
167 if table is not None:
167 if table is not None:
168 if not SQLA_06:
168 if not SQLA_06:
169 self.table = table
169 self.table = table
170 self._set_parent(table)
170 self._set_parent(table)
171 self.colnames = colnames
171 self.colnames = colnames
172
172
173 def autoname(self):
173 def autoname(self):
174 return "%(table)s_%(cols)s_check" % \
174 return "%(table)s_%(cols)s_check" % \
175 dict(table=self.table.name, cols="_".join(self.colnames))
175 dict(table=self.table.name, cols="_".join(self.colnames))
176
176
177
177
178 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
178 class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
179 """Construct UniqueConstraint
179 """Construct UniqueConstraint
180
180
181 Migrate's additional parameters:
181 Migrate's additional parameters:
182
182
183 :param cols: Columns in constraint.
183 :param cols: Columns in constraint.
184 :param table: If columns are passed as strings, this kw is required
184 :param table: If columns are passed as strings, this kw is required
185 :type table: Table instance
185 :type table: Table instance
186 :type cols: strings or Column instances
186 :type cols: strings or Column instances
187
187
188 .. versionadded:: 0.6.0
188 .. versionadded:: 0.6.0
189 """
189 """
190
190
191 __migrate_visit_name__ = 'migrate_unique_constraint'
191 __migrate_visit_name__ = 'migrate_unique_constraint'
192
192
193 def __init__(self, *cols, **kwargs):
193 def __init__(self, *cols, **kwargs):
194 self.colnames, table = self._normalize_columns(cols)
194 self.colnames, table = self._normalize_columns(cols)
195 table = kwargs.pop('table', table)
195 table = kwargs.pop('table', table)
196 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
196 super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
197 if table is not None:
197 if table is not None:
198 self._set_parent(table)
198 self._set_parent(table)
199
199
200 def autoname(self):
200 def autoname(self):
201 """Mimic the database's automatic constraint names"""
201 """Mimic the database's automatic constraint names"""
202 return "%s_%s_key" % (self.table.name, self.colnames[0])
202 return "%s_%s_key" % (self.table.name, self.colnames[0])
@@ -1,80 +1,80 b''
1 """
1 """
2 Firebird database specific implementations of changeset classes.
2 Firebird database specific implementations of changeset classes.
3 """
3 """
4 from sqlalchemy.databases import firebird as sa_base
4 from sqlalchemy.databases import firebird as sa_base
5
5
6 from migrate import exceptions
6 from rhodecode.lib.dbmigrate.migrate import exceptions
7 from migrate.changeset import ansisql, SQLA_06
7 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
8
8
9
9
10 if SQLA_06:
10 if SQLA_06:
11 FBSchemaGenerator = sa_base.FBDDLCompiler
11 FBSchemaGenerator = sa_base.FBDDLCompiler
12 else:
12 else:
13 FBSchemaGenerator = sa_base.FBSchemaGenerator
13 FBSchemaGenerator = sa_base.FBSchemaGenerator
14
14
15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
15 class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
16 """Firebird column generator implementation."""
16 """Firebird column generator implementation."""
17
17
18
18
19 class FBColumnDropper(ansisql.ANSIColumnDropper):
19 class FBColumnDropper(ansisql.ANSIColumnDropper):
20 """Firebird column dropper implementation."""
20 """Firebird column dropper implementation."""
21
21
22 def visit_column(self, column):
22 def visit_column(self, column):
23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
23 """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
24
24
25 Drop primary key and unique constraints if dropped column is referencing it."""
25 Drop primary key and unique constraints if dropped column is referencing it."""
26 if column.primary_key:
26 if column.primary_key:
27 if column.table.primary_key.columns.contains_column(column):
27 if column.table.primary_key.columns.contains_column(column):
28 column.table.primary_key.drop()
28 column.table.primary_key.drop()
29 # TODO: recreate primary key if it references more than this column
29 # TODO: recreate primary key if it references more than this column
30 if column.unique or getattr(column, 'unique_name', None):
30 if column.unique or getattr(column, 'unique_name', None):
31 for cons in column.table.constraints:
31 for cons in column.table.constraints:
32 if cons.contains_column(column):
32 if cons.contains_column(column):
33 cons.drop()
33 cons.drop()
34 # TODO: recreate unique constraint if it refenrences more than this column
34 # TODO: recreate unique constraint if it refenrences more than this column
35
35
36 table = self.start_alter_table(column)
36 table = self.start_alter_table(column)
37 self.append('DROP %s' % self.preparer.format_column(column))
37 self.append('DROP %s' % self.preparer.format_column(column))
38 self.execute()
38 self.execute()
39
39
40
40
41 class FBSchemaChanger(ansisql.ANSISchemaChanger):
41 class FBSchemaChanger(ansisql.ANSISchemaChanger):
42 """Firebird schema changer implementation."""
42 """Firebird schema changer implementation."""
43
43
44 def visit_table(self, table):
44 def visit_table(self, table):
45 """Rename table not supported"""
45 """Rename table not supported"""
46 raise exceptions.NotSupportedError(
46 raise exceptions.NotSupportedError(
47 "Firebird does not support renaming tables.")
47 "Firebird does not support renaming tables.")
48
48
49 def _visit_column_name(self, table, column, delta):
49 def _visit_column_name(self, table, column, delta):
50 self.start_alter_table(table)
50 self.start_alter_table(table)
51 col_name = self.preparer.quote(delta.current_name, table.quote)
51 col_name = self.preparer.quote(delta.current_name, table.quote)
52 new_name = self.preparer.format_column(delta.result_column)
52 new_name = self.preparer.format_column(delta.result_column)
53 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
53 self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
54
54
55 def _visit_column_nullable(self, table, column, delta):
55 def _visit_column_nullable(self, table, column, delta):
56 """Changing NULL is not supported"""
56 """Changing NULL is not supported"""
57 # TODO: http://www.firebirdfaq.org/faq103/
57 # TODO: http://www.firebirdfaq.org/faq103/
58 raise exceptions.NotSupportedError(
58 raise exceptions.NotSupportedError(
59 "Firebird does not support altering NULL bevahior.")
59 "Firebird does not support altering NULL bevahior.")
60
60
61
61
62 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
62 class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
63 """Firebird constraint generator implementation."""
63 """Firebird constraint generator implementation."""
64
64
65
65
66 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
66 class FBConstraintDropper(ansisql.ANSIConstraintDropper):
67 """Firebird constaint dropper implementation."""
67 """Firebird constaint dropper implementation."""
68
68
69 def cascade_constraint(self, constraint):
69 def cascade_constraint(self, constraint):
70 """Cascading constraints is not supported"""
70 """Cascading constraints is not supported"""
71 raise exceptions.NotSupportedError(
71 raise exceptions.NotSupportedError(
72 "Firebird does not support cascading constraints")
72 "Firebird does not support cascading constraints")
73
73
74
74
75 class FBDialect(ansisql.ANSIDialect):
75 class FBDialect(ansisql.ANSIDialect):
76 columngenerator = FBColumnGenerator
76 columngenerator = FBColumnGenerator
77 columndropper = FBColumnDropper
77 columndropper = FBColumnDropper
78 schemachanger = FBSchemaChanger
78 schemachanger = FBSchemaChanger
79 constraintgenerator = FBConstraintGenerator
79 constraintgenerator = FBConstraintGenerator
80 constraintdropper = FBConstraintDropper
80 constraintdropper = FBConstraintDropper
@@ -1,94 +1,94 b''
1 """
1 """
2 MySQL database specific implementations of changeset classes.
2 MySQL database specific implementations of changeset classes.
3 """
3 """
4
4
5 from sqlalchemy.databases import mysql as sa_base
5 from sqlalchemy.databases import mysql as sa_base
6 from sqlalchemy import types as sqltypes
6 from sqlalchemy import types as sqltypes
7
7
8 from migrate import exceptions
8 from rhodecode.lib.dbmigrate.migrate import exceptions
9 from migrate.changeset import ansisql, SQLA_06
9 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
10
10
11
11
12 if not SQLA_06:
12 if not SQLA_06:
13 MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
13 MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
14 else:
14 else:
15 MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
15 MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
16
16
17 class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
17 class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
18 pass
18 pass
19
19
20
20
21 class MySQLColumnDropper(ansisql.ANSIColumnDropper):
21 class MySQLColumnDropper(ansisql.ANSIColumnDropper):
22 pass
22 pass
23
23
24
24
25 class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
25 class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
26
26
27 def visit_column(self, delta):
27 def visit_column(self, delta):
28 table = delta.table
28 table = delta.table
29 colspec = self.get_column_specification(delta.result_column)
29 colspec = self.get_column_specification(delta.result_column)
30 if delta.result_column.autoincrement:
30 if delta.result_column.autoincrement:
31 primary_keys = [c for c in table.primary_key.columns
31 primary_keys = [c for c in table.primary_key.columns
32 if (c.autoincrement and
32 if (c.autoincrement and
33 isinstance(c.type, sqltypes.Integer) and
33 isinstance(c.type, sqltypes.Integer) and
34 not c.foreign_keys)]
34 not c.foreign_keys)]
35
35
36 if primary_keys:
36 if primary_keys:
37 first = primary_keys.pop(0)
37 first = primary_keys.pop(0)
38 if first.name == delta.current_name:
38 if first.name == delta.current_name:
39 colspec += " AUTO_INCREMENT"
39 colspec += " AUTO_INCREMENT"
40 old_col_name = self.preparer.quote(delta.current_name, table.quote)
40 old_col_name = self.preparer.quote(delta.current_name, table.quote)
41
41
42 self.start_alter_table(table)
42 self.start_alter_table(table)
43
43
44 self.append("CHANGE COLUMN %s " % old_col_name)
44 self.append("CHANGE COLUMN %s " % old_col_name)
45 self.append(colspec)
45 self.append(colspec)
46 self.execute()
46 self.execute()
47
47
48 def visit_index(self, param):
48 def visit_index(self, param):
49 # If MySQL can do this, I can't find how
49 # If MySQL can do this, I can't find how
50 raise exceptions.NotSupportedError("MySQL cannot rename indexes")
50 raise exceptions.NotSupportedError("MySQL cannot rename indexes")
51
51
52
52
53 class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
53 class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
54 pass
54 pass
55
55
56 if SQLA_06:
56 if SQLA_06:
57 class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
57 class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
58 def visit_migrate_check_constraint(self, *p, **k):
58 def visit_migrate_check_constraint(self, *p, **k):
59 raise exceptions.NotSupportedError("MySQL does not support CHECK"
59 raise exceptions.NotSupportedError("MySQL does not support CHECK"
60 " constraints, use triggers instead.")
60 " constraints, use triggers instead.")
61
61
62 else:
62 else:
63 class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
63 class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
64
64
65 def visit_migrate_primary_key_constraint(self, constraint):
65 def visit_migrate_primary_key_constraint(self, constraint):
66 self.start_alter_table(constraint)
66 self.start_alter_table(constraint)
67 self.append("DROP PRIMARY KEY")
67 self.append("DROP PRIMARY KEY")
68 self.execute()
68 self.execute()
69
69
70 def visit_migrate_foreign_key_constraint(self, constraint):
70 def visit_migrate_foreign_key_constraint(self, constraint):
71 self.start_alter_table(constraint)
71 self.start_alter_table(constraint)
72 self.append("DROP FOREIGN KEY ")
72 self.append("DROP FOREIGN KEY ")
73 constraint.name = self.get_constraint_name(constraint)
73 constraint.name = self.get_constraint_name(constraint)
74 self.append(self.preparer.format_constraint(constraint))
74 self.append(self.preparer.format_constraint(constraint))
75 self.execute()
75 self.execute()
76
76
77 def visit_migrate_check_constraint(self, *p, **k):
77 def visit_migrate_check_constraint(self, *p, **k):
78 raise exceptions.NotSupportedError("MySQL does not support CHECK"
78 raise exceptions.NotSupportedError("MySQL does not support CHECK"
79 " constraints, use triggers instead.")
79 " constraints, use triggers instead.")
80
80
81 def visit_migrate_unique_constraint(self, constraint, *p, **k):
81 def visit_migrate_unique_constraint(self, constraint, *p, **k):
82 self.start_alter_table(constraint)
82 self.start_alter_table(constraint)
83 self.append('DROP INDEX ')
83 self.append('DROP INDEX ')
84 constraint.name = self.get_constraint_name(constraint)
84 constraint.name = self.get_constraint_name(constraint)
85 self.append(self.preparer.format_constraint(constraint))
85 self.append(self.preparer.format_constraint(constraint))
86 self.execute()
86 self.execute()
87
87
88
88
89 class MySQLDialect(ansisql.ANSIDialect):
89 class MySQLDialect(ansisql.ANSIDialect):
90 columngenerator = MySQLColumnGenerator
90 columngenerator = MySQLColumnGenerator
91 columndropper = MySQLColumnDropper
91 columndropper = MySQLColumnDropper
92 schemachanger = MySQLSchemaChanger
92 schemachanger = MySQLSchemaChanger
93 constraintgenerator = MySQLConstraintGenerator
93 constraintgenerator = MySQLConstraintGenerator
94 constraintdropper = MySQLConstraintDropper
94 constraintdropper = MySQLConstraintDropper
@@ -1,111 +1,111 b''
1 """
1 """
2 Oracle database specific implementations of changeset classes.
2 Oracle database specific implementations of changeset classes.
3 """
3 """
4 import sqlalchemy as sa
4 import sqlalchemy as sa
5 from sqlalchemy.databases import oracle as sa_base
5 from sqlalchemy.databases import oracle as sa_base
6
6
7 from migrate import exceptions
7 from rhodecode.lib.dbmigrate.migrate import exceptions
8 from migrate.changeset import ansisql, SQLA_06
8 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
9
9
10
10
11 if not SQLA_06:
11 if not SQLA_06:
12 OracleSchemaGenerator = sa_base.OracleSchemaGenerator
12 OracleSchemaGenerator = sa_base.OracleSchemaGenerator
13 else:
13 else:
14 OracleSchemaGenerator = sa_base.OracleDDLCompiler
14 OracleSchemaGenerator = sa_base.OracleDDLCompiler
15
15
16
16
17 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
17 class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
18 pass
18 pass
19
19
20
20
21 class OracleColumnDropper(ansisql.ANSIColumnDropper):
21 class OracleColumnDropper(ansisql.ANSIColumnDropper):
22 pass
22 pass
23
23
24
24
25 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
25 class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
26
26
27 def get_column_specification(self, column, **kwargs):
27 def get_column_specification(self, column, **kwargs):
28 # Ignore the NOT NULL generated
28 # Ignore the NOT NULL generated
29 override_nullable = kwargs.pop('override_nullable', None)
29 override_nullable = kwargs.pop('override_nullable', None)
30 if override_nullable:
30 if override_nullable:
31 orig = column.nullable
31 orig = column.nullable
32 column.nullable = True
32 column.nullable = True
33 ret = super(OracleSchemaChanger, self).get_column_specification(
33 ret = super(OracleSchemaChanger, self).get_column_specification(
34 column, **kwargs)
34 column, **kwargs)
35 if override_nullable:
35 if override_nullable:
36 column.nullable = orig
36 column.nullable = orig
37 return ret
37 return ret
38
38
39 def visit_column(self, delta):
39 def visit_column(self, delta):
40 keys = delta.keys()
40 keys = delta.keys()
41
41
42 if 'name' in keys:
42 if 'name' in keys:
43 self._run_subvisit(delta,
43 self._run_subvisit(delta,
44 self._visit_column_name,
44 self._visit_column_name,
45 start_alter=False)
45 start_alter=False)
46
46
47 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
47 if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
48 self._run_subvisit(delta,
48 self._run_subvisit(delta,
49 self._visit_column_change,
49 self._visit_column_change,
50 start_alter=False)
50 start_alter=False)
51
51
52 def _visit_column_change(self, table, column, delta):
52 def _visit_column_change(self, table, column, delta):
53 # Oracle cannot drop a default once created, but it can set it
53 # Oracle cannot drop a default once created, but it can set it
54 # to null. We'll do that if default=None
54 # to null. We'll do that if default=None
55 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
55 # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
56 dropdefault_hack = (column.server_default is None \
56 dropdefault_hack = (column.server_default is None \
57 and 'server_default' in delta.keys())
57 and 'server_default' in delta.keys())
58 # Oracle apparently doesn't like it when we say "not null" if
58 # Oracle apparently doesn't like it when we say "not null" if
59 # the column's already not null. Fudge it, so we don't need a
59 # the column's already not null. Fudge it, so we don't need a
60 # new function
60 # new function
61 notnull_hack = ((not column.nullable) \
61 notnull_hack = ((not column.nullable) \
62 and ('nullable' not in delta.keys()))
62 and ('nullable' not in delta.keys()))
63 # We need to specify NULL if we're removing a NOT NULL
63 # We need to specify NULL if we're removing a NOT NULL
64 # constraint
64 # constraint
65 null_hack = (column.nullable and ('nullable' in delta.keys()))
65 null_hack = (column.nullable and ('nullable' in delta.keys()))
66
66
67 if dropdefault_hack:
67 if dropdefault_hack:
68 column.server_default = sa.PassiveDefault(sa.sql.null())
68 column.server_default = sa.PassiveDefault(sa.sql.null())
69 if notnull_hack:
69 if notnull_hack:
70 column.nullable = True
70 column.nullable = True
71 colspec = self.get_column_specification(column,
71 colspec = self.get_column_specification(column,
72 override_nullable=null_hack)
72 override_nullable=null_hack)
73 if null_hack:
73 if null_hack:
74 colspec += ' NULL'
74 colspec += ' NULL'
75 if notnull_hack:
75 if notnull_hack:
76 column.nullable = False
76 column.nullable = False
77 if dropdefault_hack:
77 if dropdefault_hack:
78 column.server_default = None
78 column.server_default = None
79
79
80 self.start_alter_table(table)
80 self.start_alter_table(table)
81 self.append("MODIFY (")
81 self.append("MODIFY (")
82 self.append(colspec)
82 self.append(colspec)
83 self.append(")")
83 self.append(")")
84
84
85
85
86 class OracleConstraintCommon(object):
86 class OracleConstraintCommon(object):
87
87
88 def get_constraint_name(self, cons):
88 def get_constraint_name(self, cons):
89 # Oracle constraints can't guess their name like other DBs
89 # Oracle constraints can't guess their name like other DBs
90 if not cons.name:
90 if not cons.name:
91 raise exceptions.NotSupportedError(
91 raise exceptions.NotSupportedError(
92 "Oracle constraint names must be explicitly stated")
92 "Oracle constraint names must be explicitly stated")
93 return cons.name
93 return cons.name
94
94
95
95
96 class OracleConstraintGenerator(OracleConstraintCommon,
96 class OracleConstraintGenerator(OracleConstraintCommon,
97 ansisql.ANSIConstraintGenerator):
97 ansisql.ANSIConstraintGenerator):
98 pass
98 pass
99
99
100
100
101 class OracleConstraintDropper(OracleConstraintCommon,
101 class OracleConstraintDropper(OracleConstraintCommon,
102 ansisql.ANSIConstraintDropper):
102 ansisql.ANSIConstraintDropper):
103 pass
103 pass
104
104
105
105
106 class OracleDialect(ansisql.ANSIDialect):
106 class OracleDialect(ansisql.ANSIDialect):
107 columngenerator = OracleColumnGenerator
107 columngenerator = OracleColumnGenerator
108 columndropper = OracleColumnDropper
108 columndropper = OracleColumnDropper
109 schemachanger = OracleSchemaChanger
109 schemachanger = OracleSchemaChanger
110 constraintgenerator = OracleConstraintGenerator
110 constraintgenerator = OracleConstraintGenerator
111 constraintdropper = OracleConstraintDropper
111 constraintdropper = OracleConstraintDropper
@@ -1,46 +1,46 b''
1 """
1 """
2 `PostgreSQL`_ database specific implementations of changeset classes.
2 `PostgreSQL`_ database specific implementations of changeset classes.
3
3
4 .. _`PostgreSQL`: http://www.postgresql.org/
4 .. _`PostgreSQL`: http://www.postgresql.org/
5 """
5 """
6 from migrate.changeset import ansisql, SQLA_06
6 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
7
7
8 if not SQLA_06:
8 if not SQLA_06:
9 from sqlalchemy.databases import postgres as sa_base
9 from sqlalchemy.databases import postgres as sa_base
10 PGSchemaGenerator = sa_base.PGSchemaGenerator
10 PGSchemaGenerator = sa_base.PGSchemaGenerator
11 else:
11 else:
12 from sqlalchemy.databases import postgresql as sa_base
12 from sqlalchemy.databases import postgresql as sa_base
13 PGSchemaGenerator = sa_base.PGDDLCompiler
13 PGSchemaGenerator = sa_base.PGDDLCompiler
14
14
15
15
16 class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
16 class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
17 """PostgreSQL column generator implementation."""
17 """PostgreSQL column generator implementation."""
18 pass
18 pass
19
19
20
20
21 class PGColumnDropper(ansisql.ANSIColumnDropper):
21 class PGColumnDropper(ansisql.ANSIColumnDropper):
22 """PostgreSQL column dropper implementation."""
22 """PostgreSQL column dropper implementation."""
23 pass
23 pass
24
24
25
25
26 class PGSchemaChanger(ansisql.ANSISchemaChanger):
26 class PGSchemaChanger(ansisql.ANSISchemaChanger):
27 """PostgreSQL schema changer implementation."""
27 """PostgreSQL schema changer implementation."""
28 pass
28 pass
29
29
30
30
31 class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
31 class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
32 """PostgreSQL constraint generator implementation."""
32 """PostgreSQL constraint generator implementation."""
33 pass
33 pass
34
34
35
35
36 class PGConstraintDropper(ansisql.ANSIConstraintDropper):
36 class PGConstraintDropper(ansisql.ANSIConstraintDropper):
37 """PostgreSQL constaint dropper implementation."""
37 """PostgreSQL constaint dropper implementation."""
38 pass
38 pass
39
39
40
40
41 class PGDialect(ansisql.ANSIDialect):
41 class PGDialect(ansisql.ANSIDialect):
42 columngenerator = PGColumnGenerator
42 columngenerator = PGColumnGenerator
43 columndropper = PGColumnDropper
43 columndropper = PGColumnDropper
44 schemachanger = PGSchemaChanger
44 schemachanger = PGSchemaChanger
45 constraintgenerator = PGConstraintGenerator
45 constraintgenerator = PGConstraintGenerator
46 constraintdropper = PGConstraintDropper
46 constraintdropper = PGConstraintDropper
@@ -1,148 +1,148 b''
1 """
1 """
2 `SQLite`_ database specific implementations of changeset classes.
2 `SQLite`_ database specific implementations of changeset classes.
3
3
4 .. _`SQLite`: http://www.sqlite.org/
4 .. _`SQLite`: http://www.sqlite.org/
5 """
5 """
6 from UserDict import DictMixin
6 from UserDict import DictMixin
7 from copy import copy
7 from copy import copy
8
8
9 from sqlalchemy.databases import sqlite as sa_base
9 from sqlalchemy.databases import sqlite as sa_base
10
10
11 from migrate import exceptions
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from migrate.changeset import ansisql, SQLA_06
12 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06
13
13
14
14
15 if not SQLA_06:
15 if not SQLA_06:
16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
16 SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
17 else:
17 else:
18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
18 SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
19
19
20 class SQLiteCommon(object):
20 class SQLiteCommon(object):
21
21
22 def _not_supported(self, op):
22 def _not_supported(self, op):
23 raise exceptions.NotSupportedError("SQLite does not support "
23 raise exceptions.NotSupportedError("SQLite does not support "
24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
24 "%s; see http://www.sqlite.org/lang_altertable.html" % op)
25
25
26
26
27 class SQLiteHelper(SQLiteCommon):
27 class SQLiteHelper(SQLiteCommon):
28
28
29 def recreate_table(self,table,column=None,delta=None):
29 def recreate_table(self,table,column=None,delta=None):
30 table_name = self.preparer.format_table(table)
30 table_name = self.preparer.format_table(table)
31
31
32 # we remove all indexes so as not to have
32 # we remove all indexes so as not to have
33 # problems during copy and re-create
33 # problems during copy and re-create
34 for index in table.indexes:
34 for index in table.indexes:
35 index.drop()
35 index.drop()
36
36
37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
37 self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
38 self.execute()
38 self.execute()
39
39
40 insertion_string = self._modify_table(table, column, delta)
40 insertion_string = self._modify_table(table, column, delta)
41
41
42 table.create()
42 table.create()
43 self.append(insertion_string % {'table_name': table_name})
43 self.append(insertion_string % {'table_name': table_name})
44 self.execute()
44 self.execute()
45 self.append('DROP TABLE migration_tmp')
45 self.append('DROP TABLE migration_tmp')
46 self.execute()
46 self.execute()
47
47
48 def visit_column(self, delta):
48 def visit_column(self, delta):
49 if isinstance(delta, DictMixin):
49 if isinstance(delta, DictMixin):
50 column = delta.result_column
50 column = delta.result_column
51 table = self._to_table(delta.table)
51 table = self._to_table(delta.table)
52 else:
52 else:
53 column = delta
53 column = delta
54 table = self._to_table(column.table)
54 table = self._to_table(column.table)
55 self.recreate_table(table,column,delta)
55 self.recreate_table(table,column,delta)
56
56
57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
57 class SQLiteColumnGenerator(SQLiteSchemaGenerator,
58 ansisql.ANSIColumnGenerator,
58 ansisql.ANSIColumnGenerator,
59 # at the end so we get the normal
59 # at the end so we get the normal
60 # visit_column by default
60 # visit_column by default
61 SQLiteHelper,
61 SQLiteHelper,
62 SQLiteCommon
62 SQLiteCommon
63 ):
63 ):
64 """SQLite ColumnGenerator"""
64 """SQLite ColumnGenerator"""
65
65
66 def _modify_table(self, table, column, delta):
66 def _modify_table(self, table, column, delta):
67 columns = ' ,'.join(map(
67 columns = ' ,'.join(map(
68 self.preparer.format_column,
68 self.preparer.format_column,
69 [c for c in table.columns if c.name!=column.name]))
69 [c for c in table.columns if c.name!=column.name]))
70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
70 return ('INSERT INTO %%(table_name)s (%(cols)s) '
71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
71 'SELECT %(cols)s from migration_tmp')%{'cols':columns}
72
72
73 def visit_column(self,column):
73 def visit_column(self,column):
74 if column.foreign_keys:
74 if column.foreign_keys:
75 SQLiteHelper.visit_column(self,column)
75 SQLiteHelper.visit_column(self,column)
76 else:
76 else:
77 super(SQLiteColumnGenerator,self).visit_column(column)
77 super(SQLiteColumnGenerator,self).visit_column(column)
78
78
79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
79 class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
80 """SQLite ColumnDropper"""
80 """SQLite ColumnDropper"""
81
81
82 def _modify_table(self, table, column, delta):
82 def _modify_table(self, table, column, delta):
83 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
83 columns = ' ,'.join(map(self.preparer.format_column, table.columns))
84 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
84 return 'INSERT INTO %(table_name)s SELECT ' + columns + \
85 ' from migration_tmp'
85 ' from migration_tmp'
86
86
87
87
88 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
88 class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
89 """SQLite SchemaChanger"""
89 """SQLite SchemaChanger"""
90
90
91 def _modify_table(self, table, column, delta):
91 def _modify_table(self, table, column, delta):
92 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
92 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
93
93
94 def visit_index(self, index):
94 def visit_index(self, index):
95 """Does not support ALTER INDEX"""
95 """Does not support ALTER INDEX"""
96 self._not_supported('ALTER INDEX')
96 self._not_supported('ALTER INDEX')
97
97
98
98
99 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
99 class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
100
100
101 def visit_migrate_primary_key_constraint(self, constraint):
101 def visit_migrate_primary_key_constraint(self, constraint):
102 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
102 tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
103 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
103 cols = ', '.join(map(self.preparer.format_column, constraint.columns))
104 tname = self.preparer.format_table(constraint.table)
104 tname = self.preparer.format_table(constraint.table)
105 name = self.get_constraint_name(constraint)
105 name = self.get_constraint_name(constraint)
106 msg = tmpl % (name, tname, cols)
106 msg = tmpl % (name, tname, cols)
107 self.append(msg)
107 self.append(msg)
108 self.execute()
108 self.execute()
109
109
110 def _modify_table(self, table, column, delta):
110 def _modify_table(self, table, column, delta):
111 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
111 return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
112
112
113 def visit_migrate_foreign_key_constraint(self, *p, **k):
113 def visit_migrate_foreign_key_constraint(self, *p, **k):
114 self.recreate_table(p[0].table)
114 self.recreate_table(p[0].table)
115
115
116 def visit_migrate_unique_constraint(self, *p, **k):
116 def visit_migrate_unique_constraint(self, *p, **k):
117 self.recreate_table(p[0].table)
117 self.recreate_table(p[0].table)
118
118
119
119
120 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
120 class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
121 SQLiteCommon,
121 SQLiteCommon,
122 ansisql.ANSIConstraintCommon):
122 ansisql.ANSIConstraintCommon):
123
123
124 def visit_migrate_primary_key_constraint(self, constraint):
124 def visit_migrate_primary_key_constraint(self, constraint):
125 tmpl = "DROP INDEX %s "
125 tmpl = "DROP INDEX %s "
126 name = self.get_constraint_name(constraint)
126 name = self.get_constraint_name(constraint)
127 msg = tmpl % (name)
127 msg = tmpl % (name)
128 self.append(msg)
128 self.append(msg)
129 self.execute()
129 self.execute()
130
130
131 def visit_migrate_foreign_key_constraint(self, *p, **k):
131 def visit_migrate_foreign_key_constraint(self, *p, **k):
132 self._not_supported('ALTER TABLE DROP CONSTRAINT')
132 self._not_supported('ALTER TABLE DROP CONSTRAINT')
133
133
134 def visit_migrate_check_constraint(self, *p, **k):
134 def visit_migrate_check_constraint(self, *p, **k):
135 self._not_supported('ALTER TABLE DROP CONSTRAINT')
135 self._not_supported('ALTER TABLE DROP CONSTRAINT')
136
136
137 def visit_migrate_unique_constraint(self, *p, **k):
137 def visit_migrate_unique_constraint(self, *p, **k):
138 self._not_supported('ALTER TABLE DROP CONSTRAINT')
138 self._not_supported('ALTER TABLE DROP CONSTRAINT')
139
139
140
140
141 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
141 # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
142
142
143 class SQLiteDialect(ansisql.ANSIDialect):
143 class SQLiteDialect(ansisql.ANSIDialect):
144 columngenerator = SQLiteColumnGenerator
144 columngenerator = SQLiteColumnGenerator
145 columndropper = SQLiteColumnDropper
145 columndropper = SQLiteColumnDropper
146 schemachanger = SQLiteSchemaChanger
146 schemachanger = SQLiteSchemaChanger
147 constraintgenerator = SQLiteConstraintGenerator
147 constraintgenerator = SQLiteConstraintGenerator
148 constraintdropper = SQLiteConstraintDropper
148 constraintdropper = SQLiteConstraintDropper
@@ -1,78 +1,78 b''
1 """
1 """
2 Module for visitor class mapping.
2 Module for visitor class mapping.
3 """
3 """
4 import sqlalchemy as sa
4 import sqlalchemy as sa
5
5
6 from migrate.changeset import ansisql
6 from rhodecode.lib.dbmigrate.migrate.changeset import ansisql
7 from migrate.changeset.databases import (sqlite,
7 from rhodecode.lib.dbmigrate.migrate.changeset.databases import (sqlite,
8 postgres,
8 postgres,
9 mysql,
9 mysql,
10 oracle,
10 oracle,
11 firebird)
11 firebird)
12
12
13
13
14 # Map SA dialects to the corresponding Migrate extensions
14 # Map SA dialects to the corresponding Migrate extensions
15 DIALECTS = {
15 DIALECTS = {
16 "default": ansisql.ANSIDialect,
16 "default": ansisql.ANSIDialect,
17 "sqlite": sqlite.SQLiteDialect,
17 "sqlite": sqlite.SQLiteDialect,
18 "postgres": postgres.PGDialect,
18 "postgres": postgres.PGDialect,
19 "postgresql": postgres.PGDialect,
19 "postgresql": postgres.PGDialect,
20 "mysql": mysql.MySQLDialect,
20 "mysql": mysql.MySQLDialect,
21 "oracle": oracle.OracleDialect,
21 "oracle": oracle.OracleDialect,
22 "firebird": firebird.FBDialect,
22 "firebird": firebird.FBDialect,
23 }
23 }
24
24
25
25
26 def get_engine_visitor(engine, name):
26 def get_engine_visitor(engine, name):
27 """
27 """
28 Get the visitor implementation for the given database engine.
28 Get the visitor implementation for the given database engine.
29
29
30 :param engine: SQLAlchemy Engine
30 :param engine: SQLAlchemy Engine
31 :param name: Name of the visitor
31 :param name: Name of the visitor
32 :type name: string
32 :type name: string
33 :type engine: Engine
33 :type engine: Engine
34 :returns: visitor
34 :returns: visitor
35 """
35 """
36 # TODO: link to supported visitors
36 # TODO: link to supported visitors
37 return get_dialect_visitor(engine.dialect, name)
37 return get_dialect_visitor(engine.dialect, name)
38
38
39
39
40 def get_dialect_visitor(sa_dialect, name):
40 def get_dialect_visitor(sa_dialect, name):
41 """
41 """
42 Get the visitor implementation for the given dialect.
42 Get the visitor implementation for the given dialect.
43
43
44 Finds the visitor implementation based on the dialect class and
44 Finds the visitor implementation based on the dialect class and
45 returns and instance initialized with the given name.
45 returns and instance initialized with the given name.
46
46
47 Binds dialect specific preparer to visitor.
47 Binds dialect specific preparer to visitor.
48 """
48 """
49
49
50 # map sa dialect to migrate dialect and return visitor
50 # map sa dialect to migrate dialect and return visitor
51 sa_dialect_name = getattr(sa_dialect, 'name', 'default')
51 sa_dialect_name = getattr(sa_dialect, 'name', 'default')
52 migrate_dialect_cls = DIALECTS[sa_dialect_name]
52 migrate_dialect_cls = DIALECTS[sa_dialect_name]
53 visitor = getattr(migrate_dialect_cls, name)
53 visitor = getattr(migrate_dialect_cls, name)
54
54
55 # bind preparer
55 # bind preparer
56 visitor.preparer = sa_dialect.preparer(sa_dialect)
56 visitor.preparer = sa_dialect.preparer(sa_dialect)
57
57
58 return visitor
58 return visitor
59
59
60 def run_single_visitor(engine, visitorcallable, element,
60 def run_single_visitor(engine, visitorcallable, element,
61 connection=None, **kwargs):
61 connection=None, **kwargs):
62 """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
62 """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
63 with support for migrate visitors.
63 with support for migrate visitors.
64 """
64 """
65 if connection is None:
65 if connection is None:
66 conn = engine.contextual_connect(close_with_result=False)
66 conn = engine.contextual_connect(close_with_result=False)
67 else:
67 else:
68 conn = connection
68 conn = connection
69 visitor = visitorcallable(engine.dialect, conn)
69 visitor = visitorcallable(engine.dialect, conn)
70 try:
70 try:
71 if hasattr(element, '__migrate_visit_name__'):
71 if hasattr(element, '__migrate_visit_name__'):
72 fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
72 fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
73 else:
73 else:
74 fn = getattr(visitor, 'visit_' + element.__visit_name__)
74 fn = getattr(visitor, 'visit_' + element.__visit_name__)
75 fn(element, **kwargs)
75 fn(element, **kwargs)
76 finally:
76 finally:
77 if connection is None:
77 if connection is None:
78 conn.close()
78 conn.close()
@@ -1,669 +1,669 b''
1 """
1 """
2 Schema module providing common schema operations.
2 Schema module providing common schema operations.
3 """
3 """
4 import warnings
4 import warnings
5
5
6 from UserDict import DictMixin
6 from UserDict import DictMixin
7
7
8 import sqlalchemy
8 import sqlalchemy
9
9
10 from sqlalchemy.schema import ForeignKeyConstraint
10 from sqlalchemy.schema import ForeignKeyConstraint
11 from sqlalchemy.schema import UniqueConstraint
11 from sqlalchemy.schema import UniqueConstraint
12
12
13 from migrate.exceptions import *
13 from rhodecode.lib.dbmigrate.migrate.exceptions import *
14 from migrate.changeset import SQLA_06
14 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
15 from migrate.changeset.databases.visitor import (get_engine_visitor,
15 from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor,
16 run_single_visitor)
16 run_single_visitor)
17
17
18
18
19 __all__ = [
19 __all__ = [
20 'create_column',
20 'create_column',
21 'drop_column',
21 'drop_column',
22 'alter_column',
22 'alter_column',
23 'rename_table',
23 'rename_table',
24 'rename_index',
24 'rename_index',
25 'ChangesetTable',
25 'ChangesetTable',
26 'ChangesetColumn',
26 'ChangesetColumn',
27 'ChangesetIndex',
27 'ChangesetIndex',
28 'ChangesetDefaultClause',
28 'ChangesetDefaultClause',
29 'ColumnDelta',
29 'ColumnDelta',
30 ]
30 ]
31
31
32 DEFAULT_ALTER_METADATA = True
32 DEFAULT_ALTER_METADATA = True
33
33
34
34
35 def create_column(column, table=None, *p, **kw):
35 def create_column(column, table=None, *p, **kw):
36 """Create a column, given the table.
36 """Create a column, given the table.
37
37
38 API to :meth:`ChangesetColumn.create`.
38 API to :meth:`ChangesetColumn.create`.
39 """
39 """
40 if table is not None:
40 if table is not None:
41 return table.create_column(column, *p, **kw)
41 return table.create_column(column, *p, **kw)
42 return column.create(*p, **kw)
42 return column.create(*p, **kw)
43
43
44
44
45 def drop_column(column, table=None, *p, **kw):
45 def drop_column(column, table=None, *p, **kw):
46 """Drop a column, given the table.
46 """Drop a column, given the table.
47
47
48 API to :meth:`ChangesetColumn.drop`.
48 API to :meth:`ChangesetColumn.drop`.
49 """
49 """
50 if table is not None:
50 if table is not None:
51 return table.drop_column(column, *p, **kw)
51 return table.drop_column(column, *p, **kw)
52 return column.drop(*p, **kw)
52 return column.drop(*p, **kw)
53
53
54
54
55 def rename_table(table, name, engine=None, **kw):
55 def rename_table(table, name, engine=None, **kw):
56 """Rename a table.
56 """Rename a table.
57
57
58 If Table instance is given, engine is not used.
58 If Table instance is given, engine is not used.
59
59
60 API to :meth:`ChangesetTable.rename`.
60 API to :meth:`ChangesetTable.rename`.
61
61
62 :param table: Table to be renamed.
62 :param table: Table to be renamed.
63 :param name: New name for Table.
63 :param name: New name for Table.
64 :param engine: Engine instance.
64 :param engine: Engine instance.
65 :type table: string or Table instance
65 :type table: string or Table instance
66 :type name: string
66 :type name: string
67 :type engine: obj
67 :type engine: obj
68 """
68 """
69 table = _to_table(table, engine)
69 table = _to_table(table, engine)
70 table.rename(name, **kw)
70 table.rename(name, **kw)
71
71
72
72
73 def rename_index(index, name, table=None, engine=None, **kw):
73 def rename_index(index, name, table=None, engine=None, **kw):
74 """Rename an index.
74 """Rename an index.
75
75
76 If Index instance is given,
76 If Index instance is given,
77 table and engine are not used.
77 table and engine are not used.
78
78
79 API to :meth:`ChangesetIndex.rename`.
79 API to :meth:`ChangesetIndex.rename`.
80
80
81 :param index: Index to be renamed.
81 :param index: Index to be renamed.
82 :param name: New name for index.
82 :param name: New name for index.
83 :param table: Table to which Index is reffered.
83 :param table: Table to which Index is reffered.
84 :param engine: Engine instance.
84 :param engine: Engine instance.
85 :type index: string or Index instance
85 :type index: string or Index instance
86 :type name: string
86 :type name: string
87 :type table: string or Table instance
87 :type table: string or Table instance
88 :type engine: obj
88 :type engine: obj
89 """
89 """
90 index = _to_index(index, table, engine)
90 index = _to_index(index, table, engine)
91 index.rename(name, **kw)
91 index.rename(name, **kw)
92
92
93
93
94 def alter_column(*p, **k):
94 def alter_column(*p, **k):
95 """Alter a column.
95 """Alter a column.
96
96
97 This is a helper function that creates a :class:`ColumnDelta` and
97 This is a helper function that creates a :class:`ColumnDelta` and
98 runs it.
98 runs it.
99
99
100 :argument column:
100 :argument column:
101 The name of the column to be altered or a
101 The name of the column to be altered or a
102 :class:`ChangesetColumn` column representing it.
102 :class:`ChangesetColumn` column representing it.
103
103
104 :param table:
104 :param table:
105 A :class:`~sqlalchemy.schema.Table` or table name to
105 A :class:`~sqlalchemy.schema.Table` or table name to
106 for the table where the column will be changed.
106 for the table where the column will be changed.
107
107
108 :param engine:
108 :param engine:
109 The :class:`~sqlalchemy.engine.base.Engine` to use for table
109 The :class:`~sqlalchemy.engine.base.Engine` to use for table
110 reflection and schema alterations.
110 reflection and schema alterations.
111
111
112 :param alter_metadata:
112 :param alter_metadata:
113 If `True`, which is the default, the
113 If `True`, which is the default, the
114 :class:`~sqlalchemy.schema.Column` will also modified.
114 :class:`~sqlalchemy.schema.Column` will also modified.
115 If `False`, the :class:`~sqlalchemy.schema.Column` will be left
115 If `False`, the :class:`~sqlalchemy.schema.Column` will be left
116 as it was.
116 as it was.
117
117
118 :returns: A :class:`ColumnDelta` instance representing the change.
118 :returns: A :class:`ColumnDelta` instance representing the change.
119
119
120
120
121 """
121 """
122
122
123 k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
123 k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA)
124
124
125 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
125 if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
126 k['table'] = p[0].table
126 k['table'] = p[0].table
127 if 'engine' not in k:
127 if 'engine' not in k:
128 k['engine'] = k['table'].bind
128 k['engine'] = k['table'].bind
129
129
130 # deprecation
130 # deprecation
131 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
131 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
132 warnings.warn(
132 warnings.warn(
133 "Passing a Column object to alter_column is deprecated."
133 "Passing a Column object to alter_column is deprecated."
134 " Just pass in keyword parameters instead.",
134 " Just pass in keyword parameters instead.",
135 MigrateDeprecationWarning
135 MigrateDeprecationWarning
136 )
136 )
137 engine = k['engine']
137 engine = k['engine']
138 delta = ColumnDelta(*p, **k)
138 delta = ColumnDelta(*p, **k)
139
139
140 visitorcallable = get_engine_visitor(engine, 'schemachanger')
140 visitorcallable = get_engine_visitor(engine, 'schemachanger')
141 engine._run_visitor(visitorcallable, delta)
141 engine._run_visitor(visitorcallable, delta)
142
142
143 return delta
143 return delta
144
144
145
145
146 def _to_table(table, engine=None):
146 def _to_table(table, engine=None):
147 """Return if instance of Table, else construct new with metadata"""
147 """Return if instance of Table, else construct new with metadata"""
148 if isinstance(table, sqlalchemy.Table):
148 if isinstance(table, sqlalchemy.Table):
149 return table
149 return table
150
150
151 # Given: table name, maybe an engine
151 # Given: table name, maybe an engine
152 meta = sqlalchemy.MetaData()
152 meta = sqlalchemy.MetaData()
153 if engine is not None:
153 if engine is not None:
154 meta.bind = engine
154 meta.bind = engine
155 return sqlalchemy.Table(table, meta)
155 return sqlalchemy.Table(table, meta)
156
156
157
157
158 def _to_index(index, table=None, engine=None):
158 def _to_index(index, table=None, engine=None):
159 """Return if instance of Index, else construct new with metadata"""
159 """Return if instance of Index, else construct new with metadata"""
160 if isinstance(index, sqlalchemy.Index):
160 if isinstance(index, sqlalchemy.Index):
161 return index
161 return index
162
162
163 # Given: index name; table name required
163 # Given: index name; table name required
164 table = _to_table(table, engine)
164 table = _to_table(table, engine)
165 ret = sqlalchemy.Index(index)
165 ret = sqlalchemy.Index(index)
166 ret.table = table
166 ret.table = table
167 return ret
167 return ret
168
168
169
169
170 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
170 class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
171 """Extracts the differences between two columns/column-parameters
171 """Extracts the differences between two columns/column-parameters
172
172
173 May receive parameters arranged in several different ways:
173 May receive parameters arranged in several different ways:
174
174
175 * **current_column, new_column, \*p, \*\*kw**
175 * **current_column, new_column, \*p, \*\*kw**
176 Additional parameters can be specified to override column
176 Additional parameters can be specified to override column
177 differences.
177 differences.
178
178
179 * **current_column, \*p, \*\*kw**
179 * **current_column, \*p, \*\*kw**
180 Additional parameters alter current_column. Table name is extracted
180 Additional parameters alter current_column. Table name is extracted
181 from current_column object.
181 from current_column object.
182 Name is changed to current_column.name from current_name,
182 Name is changed to current_column.name from current_name,
183 if current_name is specified.
183 if current_name is specified.
184
184
185 * **current_col_name, \*p, \*\*kw**
185 * **current_col_name, \*p, \*\*kw**
186 Table kw must specified.
186 Table kw must specified.
187
187
188 :param table: Table at which current Column should be bound to.\
188 :param table: Table at which current Column should be bound to.\
189 If table name is given, reflection will be used.
189 If table name is given, reflection will be used.
190 :type table: string or Table instance
190 :type table: string or Table instance
191 :param alter_metadata: If True, it will apply changes to metadata.
191 :param alter_metadata: If True, it will apply changes to metadata.
192 :type alter_metadata: bool
192 :type alter_metadata: bool
193 :param metadata: If `alter_metadata` is true, \
193 :param metadata: If `alter_metadata` is true, \
194 metadata is used to reflect table names into
194 metadata is used to reflect table names into
195 :type metadata: :class:`MetaData` instance
195 :type metadata: :class:`MetaData` instance
196 :param engine: When reflecting tables, either engine or metadata must \
196 :param engine: When reflecting tables, either engine or metadata must \
197 be specified to acquire engine object.
197 be specified to acquire engine object.
198 :type engine: :class:`Engine` instance
198 :type engine: :class:`Engine` instance
199 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
199 :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
200 `result_column` through :func:`dict` alike object.
200 `result_column` through :func:`dict` alike object.
201
201
202 * :class:`ColumnDelta`.result_column is altered column with new attributes
202 * :class:`ColumnDelta`.result_column is altered column with new attributes
203
203
204 * :class:`ColumnDelta`.current_name is current name of column in db
204 * :class:`ColumnDelta`.current_name is current name of column in db
205
205
206
206
207 """
207 """
208
208
209 # Column attributes that can be altered
209 # Column attributes that can be altered
210 diff_keys = ('name', 'type', 'primary_key', 'nullable',
210 diff_keys = ('name', 'type', 'primary_key', 'nullable',
211 'server_onupdate', 'server_default', 'autoincrement')
211 'server_onupdate', 'server_default', 'autoincrement')
212 diffs = dict()
212 diffs = dict()
213 __visit_name__ = 'column'
213 __visit_name__ = 'column'
214
214
215 def __init__(self, *p, **kw):
215 def __init__(self, *p, **kw):
216 self.alter_metadata = kw.pop("alter_metadata", False)
216 self.alter_metadata = kw.pop("alter_metadata", False)
217 self.meta = kw.pop("metadata", None)
217 self.meta = kw.pop("metadata", None)
218 self.engine = kw.pop("engine", None)
218 self.engine = kw.pop("engine", None)
219
219
220 # Things are initialized differently depending on how many column
220 # Things are initialized differently depending on how many column
221 # parameters are given. Figure out how many and call the appropriate
221 # parameters are given. Figure out how many and call the appropriate
222 # method.
222 # method.
223 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
223 if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
224 # At least one column specified
224 # At least one column specified
225 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
225 if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
226 # Two columns specified
226 # Two columns specified
227 diffs = self.compare_2_columns(*p, **kw)
227 diffs = self.compare_2_columns(*p, **kw)
228 else:
228 else:
229 # Exactly one column specified
229 # Exactly one column specified
230 diffs = self.compare_1_column(*p, **kw)
230 diffs = self.compare_1_column(*p, **kw)
231 else:
231 else:
232 # Zero columns specified
232 # Zero columns specified
233 if not len(p) or not isinstance(p[0], basestring):
233 if not len(p) or not isinstance(p[0], basestring):
234 raise ValueError("First argument must be column name")
234 raise ValueError("First argument must be column name")
235 diffs = self.compare_parameters(*p, **kw)
235 diffs = self.compare_parameters(*p, **kw)
236
236
237 self.apply_diffs(diffs)
237 self.apply_diffs(diffs)
238
238
239 def __repr__(self):
239 def __repr__(self):
240 return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
240 return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata,
241 super(ColumnDelta, self).__repr__())
241 super(ColumnDelta, self).__repr__())
242
242
243 def __getitem__(self, key):
243 def __getitem__(self, key):
244 if key not in self.keys():
244 if key not in self.keys():
245 raise KeyError("No such diff key, available: %s" % self.diffs)
245 raise KeyError("No such diff key, available: %s" % self.diffs)
246 return getattr(self.result_column, key)
246 return getattr(self.result_column, key)
247
247
248 def __setitem__(self, key, value):
248 def __setitem__(self, key, value):
249 if key not in self.keys():
249 if key not in self.keys():
250 raise KeyError("No such diff key, available: %s" % self.diffs)
250 raise KeyError("No such diff key, available: %s" % self.diffs)
251 setattr(self.result_column, key, value)
251 setattr(self.result_column, key, value)
252
252
253 def __delitem__(self, key):
253 def __delitem__(self, key):
254 raise NotImplementedError
254 raise NotImplementedError
255
255
256 def keys(self):
256 def keys(self):
257 return self.diffs.keys()
257 return self.diffs.keys()
258
258
259 def compare_parameters(self, current_name, *p, **k):
259 def compare_parameters(self, current_name, *p, **k):
260 """Compares Column objects with reflection"""
260 """Compares Column objects with reflection"""
261 self.table = k.pop('table')
261 self.table = k.pop('table')
262 self.result_column = self._table.c.get(current_name)
262 self.result_column = self._table.c.get(current_name)
263 if len(p):
263 if len(p):
264 k = self._extract_parameters(p, k, self.result_column)
264 k = self._extract_parameters(p, k, self.result_column)
265 return k
265 return k
266
266
267 def compare_1_column(self, col, *p, **k):
267 def compare_1_column(self, col, *p, **k):
268 """Compares one Column object"""
268 """Compares one Column object"""
269 self.table = k.pop('table', None)
269 self.table = k.pop('table', None)
270 if self.table is None:
270 if self.table is None:
271 self.table = col.table
271 self.table = col.table
272 self.result_column = col
272 self.result_column = col
273 if len(p):
273 if len(p):
274 k = self._extract_parameters(p, k, self.result_column)
274 k = self._extract_parameters(p, k, self.result_column)
275 return k
275 return k
276
276
277 def compare_2_columns(self, old_col, new_col, *p, **k):
277 def compare_2_columns(self, old_col, new_col, *p, **k):
278 """Compares two Column objects"""
278 """Compares two Column objects"""
279 self.process_column(new_col)
279 self.process_column(new_col)
280 self.table = k.pop('table', None)
280 self.table = k.pop('table', None)
281 # we cannot use bool() on table in SA06
281 # we cannot use bool() on table in SA06
282 if self.table is None:
282 if self.table is None:
283 self.table = old_col.table
283 self.table = old_col.table
284 if self.table is None:
284 if self.table is None:
285 new_col.table
285 new_col.table
286 self.result_column = old_col
286 self.result_column = old_col
287
287
288 # set differences
288 # set differences
289 # leave out some stuff for later comp
289 # leave out some stuff for later comp
290 for key in (set(self.diff_keys) - set(('type',))):
290 for key in (set(self.diff_keys) - set(('type',))):
291 val = getattr(new_col, key, None)
291 val = getattr(new_col, key, None)
292 if getattr(self.result_column, key, None) != val:
292 if getattr(self.result_column, key, None) != val:
293 k.setdefault(key, val)
293 k.setdefault(key, val)
294
294
295 # inspect types
295 # inspect types
296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
296 if not self.are_column_types_eq(self.result_column.type, new_col.type):
297 k.setdefault('type', new_col.type)
297 k.setdefault('type', new_col.type)
298
298
299 if len(p):
299 if len(p):
300 k = self._extract_parameters(p, k, self.result_column)
300 k = self._extract_parameters(p, k, self.result_column)
301 return k
301 return k
302
302
303 def apply_diffs(self, diffs):
303 def apply_diffs(self, diffs):
304 """Populate dict and column object with new values"""
304 """Populate dict and column object with new values"""
305 self.diffs = diffs
305 self.diffs = diffs
306 for key in self.diff_keys:
306 for key in self.diff_keys:
307 if key in diffs:
307 if key in diffs:
308 setattr(self.result_column, key, diffs[key])
308 setattr(self.result_column, key, diffs[key])
309
309
310 self.process_column(self.result_column)
310 self.process_column(self.result_column)
311
311
312 # create an instance of class type if not yet
312 # create an instance of class type if not yet
313 if 'type' in diffs and callable(self.result_column.type):
313 if 'type' in diffs and callable(self.result_column.type):
314 self.result_column.type = self.result_column.type()
314 self.result_column.type = self.result_column.type()
315
315
316 # add column to the table
316 # add column to the table
317 if self.table is not None and self.alter_metadata:
317 if self.table is not None and self.alter_metadata:
318 self.result_column.add_to_table(self.table)
318 self.result_column.add_to_table(self.table)
319
319
320 def are_column_types_eq(self, old_type, new_type):
320 def are_column_types_eq(self, old_type, new_type):
321 """Compares two types to be equal"""
321 """Compares two types to be equal"""
322 ret = old_type.__class__ == new_type.__class__
322 ret = old_type.__class__ == new_type.__class__
323
323
324 # String length is a special case
324 # String length is a special case
325 if ret and isinstance(new_type, sqlalchemy.types.String):
325 if ret and isinstance(new_type, sqlalchemy.types.String):
326 ret = (getattr(old_type, 'length', None) == \
326 ret = (getattr(old_type, 'length', None) == \
327 getattr(new_type, 'length', None))
327 getattr(new_type, 'length', None))
328 return ret
328 return ret
329
329
330 def _extract_parameters(self, p, k, column):
330 def _extract_parameters(self, p, k, column):
331 """Extracts data from p and modifies diffs"""
331 """Extracts data from p and modifies diffs"""
332 p = list(p)
332 p = list(p)
333 while len(p):
333 while len(p):
334 if isinstance(p[0], basestring):
334 if isinstance(p[0], basestring):
335 k.setdefault('name', p.pop(0))
335 k.setdefault('name', p.pop(0))
336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
336 elif isinstance(p[0], sqlalchemy.types.AbstractType):
337 k.setdefault('type', p.pop(0))
337 k.setdefault('type', p.pop(0))
338 elif callable(p[0]):
338 elif callable(p[0]):
339 p[0] = p[0]()
339 p[0] = p[0]()
340 else:
340 else:
341 break
341 break
342
342
343 if len(p):
343 if len(p):
344 new_col = column.copy_fixed()
344 new_col = column.copy_fixed()
345 new_col._init_items(*p)
345 new_col._init_items(*p)
346 k = self.compare_2_columns(column, new_col, **k)
346 k = self.compare_2_columns(column, new_col, **k)
347 return k
347 return k
348
348
349 def process_column(self, column):
349 def process_column(self, column):
350 """Processes default values for column"""
350 """Processes default values for column"""
351 # XXX: this is a snippet from SA processing of positional parameters
351 # XXX: this is a snippet from SA processing of positional parameters
352 if not SQLA_06 and column.args:
352 if not SQLA_06 and column.args:
353 toinit = list(column.args)
353 toinit = list(column.args)
354 else:
354 else:
355 toinit = list()
355 toinit = list()
356
356
357 if column.server_default is not None:
357 if column.server_default is not None:
358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
358 if isinstance(column.server_default, sqlalchemy.FetchedValue):
359 toinit.append(column.server_default)
359 toinit.append(column.server_default)
360 else:
360 else:
361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
361 toinit.append(sqlalchemy.DefaultClause(column.server_default))
362 if column.server_onupdate is not None:
362 if column.server_onupdate is not None:
363 if isinstance(column.server_onupdate, FetchedValue):
363 if isinstance(column.server_onupdate, FetchedValue):
364 toinit.append(column.server_default)
364 toinit.append(column.server_default)
365 else:
365 else:
366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
366 toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
367 for_update=True))
367 for_update=True))
368 if toinit:
368 if toinit:
369 column._init_items(*toinit)
369 column._init_items(*toinit)
370
370
371 if not SQLA_06:
371 if not SQLA_06:
372 column.args = []
372 column.args = []
373
373
374 def _get_table(self):
374 def _get_table(self):
375 return getattr(self, '_table', None)
375 return getattr(self, '_table', None)
376
376
377 def _set_table(self, table):
377 def _set_table(self, table):
378 if isinstance(table, basestring):
378 if isinstance(table, basestring):
379 if self.alter_metadata:
379 if self.alter_metadata:
380 if not self.meta:
380 if not self.meta:
381 raise ValueError("metadata must be specified for table"
381 raise ValueError("metadata must be specified for table"
382 " reflection when using alter_metadata")
382 " reflection when using alter_metadata")
383 meta = self.meta
383 meta = self.meta
384 if self.engine:
384 if self.engine:
385 meta.bind = self.engine
385 meta.bind = self.engine
386 else:
386 else:
387 if not self.engine and not self.meta:
387 if not self.engine and not self.meta:
388 raise ValueError("engine or metadata must be specified"
388 raise ValueError("engine or metadata must be specified"
389 " to reflect tables")
389 " to reflect tables")
390 if not self.engine:
390 if not self.engine:
391 self.engine = self.meta.bind
391 self.engine = self.meta.bind
392 meta = sqlalchemy.MetaData(bind=self.engine)
392 meta = sqlalchemy.MetaData(bind=self.engine)
393 self._table = sqlalchemy.Table(table, meta, autoload=True)
393 self._table = sqlalchemy.Table(table, meta, autoload=True)
394 elif isinstance(table, sqlalchemy.Table):
394 elif isinstance(table, sqlalchemy.Table):
395 self._table = table
395 self._table = table
396 if not self.alter_metadata:
396 if not self.alter_metadata:
397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
397 self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
398
398
399 def _get_result_column(self):
399 def _get_result_column(self):
400 return getattr(self, '_result_column', None)
400 return getattr(self, '_result_column', None)
401
401
402 def _set_result_column(self, column):
402 def _set_result_column(self, column):
403 """Set Column to Table based on alter_metadata evaluation."""
403 """Set Column to Table based on alter_metadata evaluation."""
404 self.process_column(column)
404 self.process_column(column)
405 if not hasattr(self, 'current_name'):
405 if not hasattr(self, 'current_name'):
406 self.current_name = column.name
406 self.current_name = column.name
407 if self.alter_metadata:
407 if self.alter_metadata:
408 self._result_column = column
408 self._result_column = column
409 else:
409 else:
410 self._result_column = column.copy_fixed()
410 self._result_column = column.copy_fixed()
411
411
412 table = property(_get_table, _set_table)
412 table = property(_get_table, _set_table)
413 result_column = property(_get_result_column, _set_result_column)
413 result_column = property(_get_result_column, _set_result_column)
414
414
415
415
416 class ChangesetTable(object):
416 class ChangesetTable(object):
417 """Changeset extensions to SQLAlchemy tables."""
417 """Changeset extensions to SQLAlchemy tables."""
418
418
419 def create_column(self, column, *p, **kw):
419 def create_column(self, column, *p, **kw):
420 """Creates a column.
420 """Creates a column.
421
421
422 The column parameter may be a column definition or the name of
422 The column parameter may be a column definition or the name of
423 a column in this table.
423 a column in this table.
424
424
425 API to :meth:`ChangesetColumn.create`
425 API to :meth:`ChangesetColumn.create`
426
426
427 :param column: Column to be created
427 :param column: Column to be created
428 :type column: Column instance or string
428 :type column: Column instance or string
429 """
429 """
430 if not isinstance(column, sqlalchemy.Column):
430 if not isinstance(column, sqlalchemy.Column):
431 # It's a column name
431 # It's a column name
432 column = getattr(self.c, str(column))
432 column = getattr(self.c, str(column))
433 column.create(table=self, *p, **kw)
433 column.create(table=self, *p, **kw)
434
434
435 def drop_column(self, column, *p, **kw):
435 def drop_column(self, column, *p, **kw):
436 """Drop a column, given its name or definition.
436 """Drop a column, given its name or definition.
437
437
438 API to :meth:`ChangesetColumn.drop`
438 API to :meth:`ChangesetColumn.drop`
439
439
440 :param column: Column to be droped
440 :param column: Column to be droped
441 :type column: Column instance or string
441 :type column: Column instance or string
442 """
442 """
443 if not isinstance(column, sqlalchemy.Column):
443 if not isinstance(column, sqlalchemy.Column):
444 # It's a column name
444 # It's a column name
445 try:
445 try:
446 column = getattr(self.c, str(column))
446 column = getattr(self.c, str(column))
447 except AttributeError:
447 except AttributeError:
448 # That column isn't part of the table. We don't need
448 # That column isn't part of the table. We don't need
449 # its entire definition to drop the column, just its
449 # its entire definition to drop the column, just its
450 # name, so create a dummy column with the same name.
450 # name, so create a dummy column with the same name.
451 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
451 column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
452 column.drop(table=self, *p, **kw)
452 column.drop(table=self, *p, **kw)
453
453
454 def rename(self, name, connection=None, **kwargs):
454 def rename(self, name, connection=None, **kwargs):
455 """Rename this table.
455 """Rename this table.
456
456
457 :param name: New name of the table.
457 :param name: New name of the table.
458 :type name: string
458 :type name: string
459 :param alter_metadata: If True, table will be removed from metadata
459 :param alter_metadata: If True, table will be removed from metadata
460 :type alter_metadata: bool
460 :type alter_metadata: bool
461 :param connection: reuse connection istead of creating new one.
461 :param connection: reuse connection istead of creating new one.
462 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
462 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
463 """
463 """
464 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
464 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
465 engine = self.bind
465 engine = self.bind
466 self.new_name = name
466 self.new_name = name
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
467 visitorcallable = get_engine_visitor(engine, 'schemachanger')
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
468 run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
469
469
470 # Fix metadata registration
470 # Fix metadata registration
471 if self.alter_metadata:
471 if self.alter_metadata:
472 self.name = name
472 self.name = name
473 self.deregister()
473 self.deregister()
474 self._set_parent(self.metadata)
474 self._set_parent(self.metadata)
475
475
476 def _meta_key(self):
476 def _meta_key(self):
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
477 return sqlalchemy.schema._get_table_key(self.name, self.schema)
478
478
479 def deregister(self):
479 def deregister(self):
480 """Remove this table from its metadata"""
480 """Remove this table from its metadata"""
481 key = self._meta_key()
481 key = self._meta_key()
482 meta = self.metadata
482 meta = self.metadata
483 if key in meta.tables:
483 if key in meta.tables:
484 del meta.tables[key]
484 del meta.tables[key]
485
485
486
486
487 class ChangesetColumn(object):
487 class ChangesetColumn(object):
488 """Changeset extensions to SQLAlchemy columns."""
488 """Changeset extensions to SQLAlchemy columns."""
489
489
490 def alter(self, *p, **k):
490 def alter(self, *p, **k):
491 """Makes a call to :func:`alter_column` for the column this
491 """Makes a call to :func:`alter_column` for the column this
492 method is called on.
492 method is called on.
493 """
493 """
494 if 'table' not in k:
494 if 'table' not in k:
495 k['table'] = self.table
495 k['table'] = self.table
496 if 'engine' not in k:
496 if 'engine' not in k:
497 k['engine'] = k['table'].bind
497 k['engine'] = k['table'].bind
498 return alter_column(self, *p, **k)
498 return alter_column(self, *p, **k)
499
499
500 def create(self, table=None, index_name=None, unique_name=None,
500 def create(self, table=None, index_name=None, unique_name=None,
501 primary_key_name=None, populate_default=True, connection=None, **kwargs):
501 primary_key_name=None, populate_default=True, connection=None, **kwargs):
502 """Create this column in the database.
502 """Create this column in the database.
503
503
504 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
504 Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
505 for most databases.
505 for most databases.
506
506
507 :param table: Table instance to create on.
507 :param table: Table instance to create on.
508 :param index_name: Creates :class:`ChangesetIndex` on this column.
508 :param index_name: Creates :class:`ChangesetIndex` on this column.
509 :param unique_name: Creates :class:\
509 :param unique_name: Creates :class:\
510 `~migrate.changeset.constraint.UniqueConstraint` on this column.
510 `~migrate.changeset.constraint.UniqueConstraint` on this column.
511 :param primary_key_name: Creates :class:\
511 :param primary_key_name: Creates :class:\
512 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
512 `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
513 :param alter_metadata: If True, column will be added to table object.
513 :param alter_metadata: If True, column will be added to table object.
514 :param populate_default: If True, created column will be \
514 :param populate_default: If True, created column will be \
515 populated with defaults
515 populated with defaults
516 :param connection: reuse connection istead of creating new one.
516 :param connection: reuse connection istead of creating new one.
517 :type table: Table instance
517 :type table: Table instance
518 :type index_name: string
518 :type index_name: string
519 :type unique_name: string
519 :type unique_name: string
520 :type primary_key_name: string
520 :type primary_key_name: string
521 :type alter_metadata: bool
521 :type alter_metadata: bool
522 :type populate_default: bool
522 :type populate_default: bool
523 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
523 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
524
524
525 :returns: self
525 :returns: self
526 """
526 """
527 self.populate_default = populate_default
527 self.populate_default = populate_default
528 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
528 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
529 self.index_name = index_name
529 self.index_name = index_name
530 self.unique_name = unique_name
530 self.unique_name = unique_name
531 self.primary_key_name = primary_key_name
531 self.primary_key_name = primary_key_name
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
532 for cons in ('index_name', 'unique_name', 'primary_key_name'):
533 self._check_sanity_constraints(cons)
533 self._check_sanity_constraints(cons)
534
534
535 if self.alter_metadata:
535 if self.alter_metadata:
536 self.add_to_table(table)
536 self.add_to_table(table)
537 engine = self.table.bind
537 engine = self.table.bind
538 visitorcallable = get_engine_visitor(engine, 'columngenerator')
538 visitorcallable = get_engine_visitor(engine, 'columngenerator')
539 engine._run_visitor(visitorcallable, self, connection, **kwargs)
539 engine._run_visitor(visitorcallable, self, connection, **kwargs)
540
540
541 # TODO: reuse existing connection
541 # TODO: reuse existing connection
542 if self.populate_default and self.default is not None:
542 if self.populate_default and self.default is not None:
543 stmt = table.update().values({self: engine._execute_default(self.default)})
543 stmt = table.update().values({self: engine._execute_default(self.default)})
544 engine.execute(stmt)
544 engine.execute(stmt)
545
545
546 return self
546 return self
547
547
548 def drop(self, table=None, connection=None, **kwargs):
548 def drop(self, table=None, connection=None, **kwargs):
549 """Drop this column from the database, leaving its table intact.
549 """Drop this column from the database, leaving its table intact.
550
550
551 ``ALTER TABLE DROP COLUMN``, for most databases.
551 ``ALTER TABLE DROP COLUMN``, for most databases.
552
552
553 :param alter_metadata: If True, column will be removed from table object.
553 :param alter_metadata: If True, column will be removed from table object.
554 :type alter_metadata: bool
554 :type alter_metadata: bool
555 :param connection: reuse connection istead of creating new one.
555 :param connection: reuse connection istead of creating new one.
556 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
556 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
557 """
557 """
558 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
558 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
559 if table is not None:
559 if table is not None:
560 self.table = table
560 self.table = table
561 engine = self.table.bind
561 engine = self.table.bind
562 if self.alter_metadata:
562 if self.alter_metadata:
563 self.remove_from_table(self.table, unset_table=False)
563 self.remove_from_table(self.table, unset_table=False)
564 visitorcallable = get_engine_visitor(engine, 'columndropper')
564 visitorcallable = get_engine_visitor(engine, 'columndropper')
565 engine._run_visitor(visitorcallable, self, connection, **kwargs)
565 engine._run_visitor(visitorcallable, self, connection, **kwargs)
566 if self.alter_metadata:
566 if self.alter_metadata:
567 self.table = None
567 self.table = None
568 return self
568 return self
569
569
570 def add_to_table(self, table):
570 def add_to_table(self, table):
571 if table is not None and self.table is None:
571 if table is not None and self.table is None:
572 self._set_parent(table)
572 self._set_parent(table)
573
573
574 def _col_name_in_constraint(self, cons, name):
574 def _col_name_in_constraint(self, cons, name):
575 return False
575 return False
576
576
577 def remove_from_table(self, table, unset_table=True):
577 def remove_from_table(self, table, unset_table=True):
578 # TODO: remove primary keys, constraints, etc
578 # TODO: remove primary keys, constraints, etc
579 if unset_table:
579 if unset_table:
580 self.table = None
580 self.table = None
581
581
582 to_drop = set()
582 to_drop = set()
583 for index in table.indexes:
583 for index in table.indexes:
584 columns = []
584 columns = []
585 for col in index.columns:
585 for col in index.columns:
586 if col.name != self.name:
586 if col.name != self.name:
587 columns.append(col)
587 columns.append(col)
588 if columns:
588 if columns:
589 index.columns = columns
589 index.columns = columns
590 else:
590 else:
591 to_drop.add(index)
591 to_drop.add(index)
592 table.indexes = table.indexes - to_drop
592 table.indexes = table.indexes - to_drop
593
593
594 to_drop = set()
594 to_drop = set()
595 for cons in table.constraints:
595 for cons in table.constraints:
596 # TODO: deal with other types of constraint
596 # TODO: deal with other types of constraint
597 if isinstance(cons, (ForeignKeyConstraint,
597 if isinstance(cons, (ForeignKeyConstraint,
598 UniqueConstraint)):
598 UniqueConstraint)):
599 for col_name in cons.columns:
599 for col_name in cons.columns:
600 if not isinstance(col_name, basestring):
600 if not isinstance(col_name, basestring):
601 col_name = col_name.name
601 col_name = col_name.name
602 if self.name == col_name:
602 if self.name == col_name:
603 to_drop.add(cons)
603 to_drop.add(cons)
604 table.constraints = table.constraints - to_drop
604 table.constraints = table.constraints - to_drop
605
605
606 if table.c.contains_column(self):
606 if table.c.contains_column(self):
607 table.c.remove(self)
607 table.c.remove(self)
608
608
609 # TODO: this is fixed in 0.6
609 # TODO: this is fixed in 0.6
610 def copy_fixed(self, **kw):
610 def copy_fixed(self, **kw):
611 """Create a copy of this ``Column``, with all attributes."""
611 """Create a copy of this ``Column``, with all attributes."""
612 return sqlalchemy.Column(self.name, self.type, self.default,
612 return sqlalchemy.Column(self.name, self.type, self.default,
613 key=self.key,
613 key=self.key,
614 primary_key=self.primary_key,
614 primary_key=self.primary_key,
615 nullable=self.nullable,
615 nullable=self.nullable,
616 quote=self.quote,
616 quote=self.quote,
617 index=self.index,
617 index=self.index,
618 unique=self.unique,
618 unique=self.unique,
619 onupdate=self.onupdate,
619 onupdate=self.onupdate,
620 autoincrement=self.autoincrement,
620 autoincrement=self.autoincrement,
621 server_default=self.server_default,
621 server_default=self.server_default,
622 server_onupdate=self.server_onupdate,
622 server_onupdate=self.server_onupdate,
623 *[c.copy(**kw) for c in self.constraints])
623 *[c.copy(**kw) for c in self.constraints])
624
624
625 def _check_sanity_constraints(self, name):
625 def _check_sanity_constraints(self, name):
626 """Check if constraints names are correct"""
626 """Check if constraints names are correct"""
627 obj = getattr(self, name)
627 obj = getattr(self, name)
628 if (getattr(self, name[:-5]) and not obj):
628 if (getattr(self, name[:-5]) and not obj):
629 raise InvalidConstraintError("Column.create() accepts index_name,"
629 raise InvalidConstraintError("Column.create() accepts index_name,"
630 " primary_key_name and unique_name to generate constraints")
630 " primary_key_name and unique_name to generate constraints")
631 if not isinstance(obj, basestring) and obj is not None:
631 if not isinstance(obj, basestring) and obj is not None:
632 raise InvalidConstraintError(
632 raise InvalidConstraintError(
633 "%s argument for column must be constraint name" % name)
633 "%s argument for column must be constraint name" % name)
634
634
635
635
636 class ChangesetIndex(object):
636 class ChangesetIndex(object):
637 """Changeset extensions to SQLAlchemy Indexes."""
637 """Changeset extensions to SQLAlchemy Indexes."""
638
638
639 __visit_name__ = 'index'
639 __visit_name__ = 'index'
640
640
641 def rename(self, name, connection=None, **kwargs):
641 def rename(self, name, connection=None, **kwargs):
642 """Change the name of an index.
642 """Change the name of an index.
643
643
644 :param name: New name of the Index.
644 :param name: New name of the Index.
645 :type name: string
645 :type name: string
646 :param alter_metadata: If True, Index object will be altered.
646 :param alter_metadata: If True, Index object will be altered.
647 :type alter_metadata: bool
647 :type alter_metadata: bool
648 :param connection: reuse connection istead of creating new one.
648 :param connection: reuse connection istead of creating new one.
649 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
649 :type connection: :class:`sqlalchemy.engine.base.Connection` instance
650 """
650 """
651 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
651 self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA)
652 engine = self.table.bind
652 engine = self.table.bind
653 self.new_name = name
653 self.new_name = name
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
654 visitorcallable = get_engine_visitor(engine, 'schemachanger')
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
655 engine._run_visitor(visitorcallable, self, connection, **kwargs)
656 if self.alter_metadata:
656 if self.alter_metadata:
657 self.name = name
657 self.name = name
658
658
659
659
660 class ChangesetDefaultClause(object):
660 class ChangesetDefaultClause(object):
661 """Implements comparison between :class:`DefaultClause` instances"""
661 """Implements comparison between :class:`DefaultClause` instances"""
662
662
663 def __eq__(self, other):
663 def __eq__(self, other):
664 if isinstance(other, self.__class__):
664 if isinstance(other, self.__class__):
665 if self.arg == other.arg:
665 if self.arg == other.arg:
666 return True
666 return True
667
667
668 def __ne__(self, other):
668 def __ne__(self, other):
669 return not self.__eq__(other)
669 return not self.__eq__(other)
@@ -1,383 +1,383 b''
1 """
1 """
2 This module provides an external API to the versioning system.
2 This module provides an external API to the versioning system.
3
3
4 .. versionchanged:: 0.6.0
4 .. versionchanged:: 0.6.0
5 :func:`migrate.versioning.api.test` and schema diff functions
5 :func:`migrate.versioning.api.test` and schema diff functions
6 changed order of positional arguments so all accept `url` and `repository`
6 changed order of positional arguments so all accept `url` and `repository`
7 as first arguments.
7 as first arguments.
8
8
9 .. versionchanged:: 0.5.4
9 .. versionchanged:: 0.5.4
10 ``--preview_sql`` displays source file when using SQL scripts.
10 ``--preview_sql`` displays source file when using SQL scripts.
11 If Python script is used, it runs the action with mocked engine and
11 If Python script is used, it runs the action with mocked engine and
12 returns captured SQL statements.
12 returns captured SQL statements.
13
13
14 .. versionchanged:: 0.5.4
14 .. versionchanged:: 0.5.4
15 Deprecated ``--echo`` parameter in favour of new
15 Deprecated ``--echo`` parameter in favour of new
16 :func:`migrate.versioning.util.construct_engine` behavior.
16 :func:`migrate.versioning.util.construct_engine` behavior.
17 """
17 """
18
18
19 # Dear migrate developers,
19 # Dear migrate developers,
20 #
20 #
21 # please do not comment this module using sphinx syntax because its
21 # please do not comment this module using sphinx syntax because its
22 # docstrings are presented as user help and most users cannot
22 # docstrings are presented as user help and most users cannot
23 # interpret sphinx annotated ReStructuredText.
23 # interpret sphinx annotated ReStructuredText.
24 #
24 #
25 # Thanks,
25 # Thanks,
26 # Jan Dittberner
26 # Jan Dittberner
27
27
28 import sys
28 import sys
29 import inspect
29 import inspect
30 import logging
30 import logging
31
31
32 from migrate import exceptions
32 from rhodecode.lib.dbmigrate.migrate import exceptions
33 from migrate.versioning import (repository, schema, version,
33 from rhodecode.lib.dbmigrate.migrate.versioning import (repository, schema, version,
34 script as script_) # command name conflict
34 script as script_) # command name conflict
35 from migrate.versioning.util import catch_known_errors, with_engine
35 from rhodecode.lib.dbmigrate.migrate.versioning.util import catch_known_errors, with_engine
36
36
37
37
38 log = logging.getLogger(__name__)
38 log = logging.getLogger(__name__)
39 command_desc = {
39 command_desc = {
40 'help': 'displays help on a given command',
40 'help': 'displays help on a given command',
41 'create': 'create an empty repository at the specified path',
41 'create': 'create an empty repository at the specified path',
42 'script': 'create an empty change Python script',
42 'script': 'create an empty change Python script',
43 'script_sql': 'create empty change SQL scripts for given database',
43 'script_sql': 'create empty change SQL scripts for given database',
44 'version': 'display the latest version available in a repository',
44 'version': 'display the latest version available in a repository',
45 'db_version': 'show the current version of the repository under version control',
45 'db_version': 'show the current version of the repository under version control',
46 'source': 'display the Python code for a particular version in this repository',
46 'source': 'display the Python code for a particular version in this repository',
47 'version_control': 'mark a database as under this repository\'s version control',
47 'version_control': 'mark a database as under this repository\'s version control',
48 'upgrade': 'upgrade a database to a later version',
48 'upgrade': 'upgrade a database to a later version',
49 'downgrade': 'downgrade a database to an earlier version',
49 'downgrade': 'downgrade a database to an earlier version',
50 'drop_version_control': 'removes version control from a database',
50 'drop_version_control': 'removes version control from a database',
51 'manage': 'creates a Python script that runs Migrate with a set of default values',
51 'manage': 'creates a Python script that runs Migrate with a set of default values',
52 'test': 'performs the upgrade and downgrade command on the given database',
52 'test': 'performs the upgrade and downgrade command on the given database',
53 'compare_model_to_db': 'compare MetaData against the current database state',
53 'compare_model_to_db': 'compare MetaData against the current database state',
54 'create_model': 'dump the current database as a Python model to stdout',
54 'create_model': 'dump the current database as a Python model to stdout',
55 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
55 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
56 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
56 'update_db_from_model': 'modify the database to match the structure of the current MetaData',
57 }
57 }
58 __all__ = command_desc.keys()
58 __all__ = command_desc.keys()
59
59
60 Repository = repository.Repository
60 Repository = repository.Repository
61 ControlledSchema = schema.ControlledSchema
61 ControlledSchema = schema.ControlledSchema
62 VerNum = version.VerNum
62 VerNum = version.VerNum
63 PythonScript = script_.PythonScript
63 PythonScript = script_.PythonScript
64 SqlScript = script_.SqlScript
64 SqlScript = script_.SqlScript
65
65
66
66
67 # deprecated
67 # deprecated
68 def help(cmd=None, **opts):
68 def help(cmd=None, **opts):
69 """%prog help COMMAND
69 """%prog help COMMAND
70
70
71 Displays help on a given command.
71 Displays help on a given command.
72 """
72 """
73 if cmd is None:
73 if cmd is None:
74 raise exceptions.UsageError(None)
74 raise exceptions.UsageError(None)
75 try:
75 try:
76 func = globals()[cmd]
76 func = globals()[cmd]
77 except:
77 except:
78 raise exceptions.UsageError(
78 raise exceptions.UsageError(
79 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
79 "'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
80 ret = func.__doc__
80 ret = func.__doc__
81 if sys.argv[0]:
81 if sys.argv[0]:
82 ret = ret.replace('%prog', sys.argv[0])
82 ret = ret.replace('%prog', sys.argv[0])
83 return ret
83 return ret
84
84
85 @catch_known_errors
85 @catch_known_errors
86 def create(repository, name, **opts):
86 def create(repository, name, **opts):
87 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
87 """%prog create REPOSITORY_PATH NAME [--table=TABLE]
88
88
89 Create an empty repository at the specified path.
89 Create an empty repository at the specified path.
90
90
91 You can specify the version_table to be used; by default, it is
91 You can specify the version_table to be used; by default, it is
92 'migrate_version'. This table is created in all version-controlled
92 'migrate_version'. This table is created in all version-controlled
93 databases.
93 databases.
94 """
94 """
95 repo_path = Repository.create(repository, name, **opts)
95 repo_path = Repository.create(repository, name, **opts)
96
96
97
97
98 @catch_known_errors
98 @catch_known_errors
99 def script(description, repository, **opts):
99 def script(description, repository, **opts):
100 """%prog script DESCRIPTION REPOSITORY_PATH
100 """%prog script DESCRIPTION REPOSITORY_PATH
101
101
102 Create an empty change script using the next unused version number
102 Create an empty change script using the next unused version number
103 appended with the given description.
103 appended with the given description.
104
104
105 For instance, manage.py script "Add initial tables" creates:
105 For instance, manage.py script "Add initial tables" creates:
106 repository/versions/001_Add_initial_tables.py
106 repository/versions/001_Add_initial_tables.py
107 """
107 """
108 repo = Repository(repository)
108 repo = Repository(repository)
109 repo.create_script(description, **opts)
109 repo.create_script(description, **opts)
110
110
111
111
112 @catch_known_errors
112 @catch_known_errors
113 def script_sql(database, repository, **opts):
113 def script_sql(database, repository, **opts):
114 """%prog script_sql DATABASE REPOSITORY_PATH
114 """%prog script_sql DATABASE REPOSITORY_PATH
115
115
116 Create empty change SQL scripts for given DATABASE, where DATABASE
116 Create empty change SQL scripts for given DATABASE, where DATABASE
117 is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
117 is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.)
118 or generic ('default').
118 or generic ('default').
119
119
120 For instance, manage.py script_sql postgres creates:
120 For instance, manage.py script_sql postgres creates:
121 repository/versions/001_postgres_upgrade.sql and
121 repository/versions/001_postgres_upgrade.sql and
122 repository/versions/001_postgres_postgres.sql
122 repository/versions/001_postgres_postgres.sql
123 """
123 """
124 repo = Repository(repository)
124 repo = Repository(repository)
125 repo.create_script_sql(database, **opts)
125 repo.create_script_sql(database, **opts)
126
126
127
127
128 def version(repository, **opts):
128 def version(repository, **opts):
129 """%prog version REPOSITORY_PATH
129 """%prog version REPOSITORY_PATH
130
130
131 Display the latest version available in a repository.
131 Display the latest version available in a repository.
132 """
132 """
133 repo = Repository(repository)
133 repo = Repository(repository)
134 return repo.latest
134 return repo.latest
135
135
136
136
137 @with_engine
137 @with_engine
138 def db_version(url, repository, **opts):
138 def db_version(url, repository, **opts):
139 """%prog db_version URL REPOSITORY_PATH
139 """%prog db_version URL REPOSITORY_PATH
140
140
141 Show the current version of the repository with the given
141 Show the current version of the repository with the given
142 connection string, under version control of the specified
142 connection string, under version control of the specified
143 repository.
143 repository.
144
144
145 The url should be any valid SQLAlchemy connection string.
145 The url should be any valid SQLAlchemy connection string.
146 """
146 """
147 engine = opts.pop('engine')
147 engine = opts.pop('engine')
148 schema = ControlledSchema(engine, repository)
148 schema = ControlledSchema(engine, repository)
149 return schema.version
149 return schema.version
150
150
151
151
152 def source(version, dest=None, repository=None, **opts):
152 def source(version, dest=None, repository=None, **opts):
153 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
153 """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
154
154
155 Display the Python code for a particular version in this
155 Display the Python code for a particular version in this
156 repository. Save it to the file at DESTINATION or, if omitted,
156 repository. Save it to the file at DESTINATION or, if omitted,
157 send to stdout.
157 send to stdout.
158 """
158 """
159 if repository is None:
159 if repository is None:
160 raise exceptions.UsageError("A repository must be specified")
160 raise exceptions.UsageError("A repository must be specified")
161 repo = Repository(repository)
161 repo = Repository(repository)
162 ret = repo.version(version).script().source()
162 ret = repo.version(version).script().source()
163 if dest is not None:
163 if dest is not None:
164 dest = open(dest, 'w')
164 dest = open(dest, 'w')
165 dest.write(ret)
165 dest.write(ret)
166 dest.close()
166 dest.close()
167 ret = None
167 ret = None
168 return ret
168 return ret
169
169
170
170
171 def upgrade(url, repository, version=None, **opts):
171 def upgrade(url, repository, version=None, **opts):
172 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
172 """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
173
173
174 Upgrade a database to a later version.
174 Upgrade a database to a later version.
175
175
176 This runs the upgrade() function defined in your change scripts.
176 This runs the upgrade() function defined in your change scripts.
177
177
178 By default, the database is updated to the latest available
178 By default, the database is updated to the latest available
179 version. You may specify a version instead, if you wish.
179 version. You may specify a version instead, if you wish.
180
180
181 You may preview the Python or SQL code to be executed, rather than
181 You may preview the Python or SQL code to be executed, rather than
182 actually executing it, using the appropriate 'preview' option.
182 actually executing it, using the appropriate 'preview' option.
183 """
183 """
184 err = "Cannot upgrade a database of version %s to version %s. "\
184 err = "Cannot upgrade a database of version %s to version %s. "\
185 "Try 'downgrade' instead."
185 "Try 'downgrade' instead."
186 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
186 return _migrate(url, repository, version, upgrade=True, err=err, **opts)
187
187
188
188
189 def downgrade(url, repository, version, **opts):
189 def downgrade(url, repository, version, **opts):
190 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
190 """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
191
191
192 Downgrade a database to an earlier version.
192 Downgrade a database to an earlier version.
193
193
194 This is the reverse of upgrade; this runs the downgrade() function
194 This is the reverse of upgrade; this runs the downgrade() function
195 defined in your change scripts.
195 defined in your change scripts.
196
196
197 You may preview the Python or SQL code to be executed, rather than
197 You may preview the Python or SQL code to be executed, rather than
198 actually executing it, using the appropriate 'preview' option.
198 actually executing it, using the appropriate 'preview' option.
199 """
199 """
200 err = "Cannot downgrade a database of version %s to version %s. "\
200 err = "Cannot downgrade a database of version %s to version %s. "\
201 "Try 'upgrade' instead."
201 "Try 'upgrade' instead."
202 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
202 return _migrate(url, repository, version, upgrade=False, err=err, **opts)
203
203
204 @with_engine
204 @with_engine
205 def test(url, repository, **opts):
205 def test(url, repository, **opts):
206 """%prog test URL REPOSITORY_PATH [VERSION]
206 """%prog test URL REPOSITORY_PATH [VERSION]
207
207
208 Performs the upgrade and downgrade option on the given
208 Performs the upgrade and downgrade option on the given
209 database. This is not a real test and may leave the database in a
209 database. This is not a real test and may leave the database in a
210 bad state. You should therefore better run the test on a copy of
210 bad state. You should therefore better run the test on a copy of
211 your database.
211 your database.
212 """
212 """
213 engine = opts.pop('engine')
213 engine = opts.pop('engine')
214 repos = Repository(repository)
214 repos = Repository(repository)
215 script = repos.version(None).script()
215 script = repos.version(None).script()
216
216
217 # Upgrade
217 # Upgrade
218 log.info("Upgrading...")
218 log.info("Upgrading...")
219 script.run(engine, 1)
219 script.run(engine, 1)
220 log.info("done")
220 log.info("done")
221
221
222 log.info("Downgrading...")
222 log.info("Downgrading...")
223 script.run(engine, -1)
223 script.run(engine, -1)
224 log.info("done")
224 log.info("done")
225 log.info("Success")
225 log.info("Success")
226
226
227
227
228 @with_engine
228 @with_engine
229 def version_control(url, repository, version=None, **opts):
229 def version_control(url, repository, version=None, **opts):
230 """%prog version_control URL REPOSITORY_PATH [VERSION]
230 """%prog version_control URL REPOSITORY_PATH [VERSION]
231
231
232 Mark a database as under this repository's version control.
232 Mark a database as under this repository's version control.
233
233
234 Once a database is under version control, schema changes should
234 Once a database is under version control, schema changes should
235 only be done via change scripts in this repository.
235 only be done via change scripts in this repository.
236
236
237 This creates the table version_table in the database.
237 This creates the table version_table in the database.
238
238
239 The url should be any valid SQLAlchemy connection string.
239 The url should be any valid SQLAlchemy connection string.
240
240
241 By default, the database begins at version 0 and is assumed to be
241 By default, the database begins at version 0 and is assumed to be
242 empty. If the database is not empty, you may specify a version at
242 empty. If the database is not empty, you may specify a version at
243 which to begin instead. No attempt is made to verify this
243 which to begin instead. No attempt is made to verify this
244 version's correctness - the database schema is expected to be
244 version's correctness - the database schema is expected to be
245 identical to what it would be if the database were created from
245 identical to what it would be if the database were created from
246 scratch.
246 scratch.
247 """
247 """
248 engine = opts.pop('engine')
248 engine = opts.pop('engine')
249 ControlledSchema.create(engine, repository, version)
249 ControlledSchema.create(engine, repository, version)
250
250
251
251
252 @with_engine
252 @with_engine
253 def drop_version_control(url, repository, **opts):
253 def drop_version_control(url, repository, **opts):
254 """%prog drop_version_control URL REPOSITORY_PATH
254 """%prog drop_version_control URL REPOSITORY_PATH
255
255
256 Removes version control from a database.
256 Removes version control from a database.
257 """
257 """
258 engine = opts.pop('engine')
258 engine = opts.pop('engine')
259 schema = ControlledSchema(engine, repository)
259 schema = ControlledSchema(engine, repository)
260 schema.drop()
260 schema.drop()
261
261
262
262
263 def manage(file, **opts):
263 def manage(file, **opts):
264 """%prog manage FILENAME [VARIABLES...]
264 """%prog manage FILENAME [VARIABLES...]
265
265
266 Creates a script that runs Migrate with a set of default values.
266 Creates a script that runs Migrate with a set of default values.
267
267
268 For example::
268 For example::
269
269
270 %prog manage manage.py --repository=/path/to/repository \
270 %prog manage manage.py --repository=/path/to/repository \
271 --url=sqlite:///project.db
271 --url=sqlite:///project.db
272
272
273 would create the script manage.py. The following two commands
273 would create the script manage.py. The following two commands
274 would then have exactly the same results::
274 would then have exactly the same results::
275
275
276 python manage.py version
276 python manage.py version
277 %prog version --repository=/path/to/repository
277 %prog version --repository=/path/to/repository
278 """
278 """
279 Repository.create_manage_file(file, **opts)
279 Repository.create_manage_file(file, **opts)
280
280
281
281
282 @with_engine
282 @with_engine
283 def compare_model_to_db(url, repository, model, **opts):
283 def compare_model_to_db(url, repository, model, **opts):
284 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
284 """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
285
285
286 Compare the current model (assumed to be a module level variable
286 Compare the current model (assumed to be a module level variable
287 of type sqlalchemy.MetaData) against the current database.
287 of type sqlalchemy.MetaData) against the current database.
288
288
289 NOTE: This is EXPERIMENTAL.
289 NOTE: This is EXPERIMENTAL.
290 """ # TODO: get rid of EXPERIMENTAL label
290 """ # TODO: get rid of EXPERIMENTAL label
291 engine = opts.pop('engine')
291 engine = opts.pop('engine')
292 return ControlledSchema.compare_model_to_db(engine, model, repository)
292 return ControlledSchema.compare_model_to_db(engine, model, repository)
293
293
294
294
295 @with_engine
295 @with_engine
296 def create_model(url, repository, **opts):
296 def create_model(url, repository, **opts):
297 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
297 """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
298
298
299 Dump the current database as a Python model to stdout.
299 Dump the current database as a Python model to stdout.
300
300
301 NOTE: This is EXPERIMENTAL.
301 NOTE: This is EXPERIMENTAL.
302 """ # TODO: get rid of EXPERIMENTAL label
302 """ # TODO: get rid of EXPERIMENTAL label
303 engine = opts.pop('engine')
303 engine = opts.pop('engine')
304 declarative = opts.get('declarative', False)
304 declarative = opts.get('declarative', False)
305 return ControlledSchema.create_model(engine, repository, declarative)
305 return ControlledSchema.create_model(engine, repository, declarative)
306
306
307
307
308 @catch_known_errors
308 @catch_known_errors
309 @with_engine
309 @with_engine
310 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
310 def make_update_script_for_model(url, repository, oldmodel, model, **opts):
311 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
311 """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
312
312
313 Create a script changing the old Python model to the new (current)
313 Create a script changing the old Python model to the new (current)
314 Python model, sending to stdout.
314 Python model, sending to stdout.
315
315
316 NOTE: This is EXPERIMENTAL.
316 NOTE: This is EXPERIMENTAL.
317 """ # TODO: get rid of EXPERIMENTAL label
317 """ # TODO: get rid of EXPERIMENTAL label
318 engine = opts.pop('engine')
318 engine = opts.pop('engine')
319 return PythonScript.make_update_script_for_model(
319 return PythonScript.make_update_script_for_model(
320 engine, oldmodel, model, repository, **opts)
320 engine, oldmodel, model, repository, **opts)
321
321
322
322
323 @with_engine
323 @with_engine
324 def update_db_from_model(url, repository, model, **opts):
324 def update_db_from_model(url, repository, model, **opts):
325 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
325 """%prog update_db_from_model URL REPOSITORY_PATH MODEL
326
326
327 Modify the database to match the structure of the current Python
327 Modify the database to match the structure of the current Python
328 model. This also sets the db_version number to the latest in the
328 model. This also sets the db_version number to the latest in the
329 repository.
329 repository.
330
330
331 NOTE: This is EXPERIMENTAL.
331 NOTE: This is EXPERIMENTAL.
332 """ # TODO: get rid of EXPERIMENTAL label
332 """ # TODO: get rid of EXPERIMENTAL label
333 engine = opts.pop('engine')
333 engine = opts.pop('engine')
334 schema = ControlledSchema(engine, repository)
334 schema = ControlledSchema(engine, repository)
335 schema.update_db_from_model(model)
335 schema.update_db_from_model(model)
336
336
337 @with_engine
337 @with_engine
338 def _migrate(url, repository, version, upgrade, err, **opts):
338 def _migrate(url, repository, version, upgrade, err, **opts):
339 engine = opts.pop('engine')
339 engine = opts.pop('engine')
340 url = str(engine.url)
340 url = str(engine.url)
341 schema = ControlledSchema(engine, repository)
341 schema = ControlledSchema(engine, repository)
342 version = _migrate_version(schema, version, upgrade, err)
342 version = _migrate_version(schema, version, upgrade, err)
343
343
344 changeset = schema.changeset(version)
344 changeset = schema.changeset(version)
345 for ver, change in changeset:
345 for ver, change in changeset:
346 nextver = ver + changeset.step
346 nextver = ver + changeset.step
347 log.info('%s -> %s... ', ver, nextver)
347 log.info('%s -> %s... ', ver, nextver)
348
348
349 if opts.get('preview_sql'):
349 if opts.get('preview_sql'):
350 if isinstance(change, PythonScript):
350 if isinstance(change, PythonScript):
351 log.info(change.preview_sql(url, changeset.step, **opts))
351 log.info(change.preview_sql(url, changeset.step, **opts))
352 elif isinstance(change, SqlScript):
352 elif isinstance(change, SqlScript):
353 log.info(change.source())
353 log.info(change.source())
354
354
355 elif opts.get('preview_py'):
355 elif opts.get('preview_py'):
356 if not isinstance(change, PythonScript):
356 if not isinstance(change, PythonScript):
357 raise exceptions.UsageError("Python source can be only displayed"
357 raise exceptions.UsageError("Python source can be only displayed"
358 " for python migration files")
358 " for python migration files")
359 source_ver = max(ver, nextver)
359 source_ver = max(ver, nextver)
360 module = schema.repository.version(source_ver).script().module
360 module = schema.repository.version(source_ver).script().module
361 funcname = upgrade and "upgrade" or "downgrade"
361 funcname = upgrade and "upgrade" or "downgrade"
362 func = getattr(module, funcname)
362 func = getattr(module, funcname)
363 log.info(inspect.getsource(func))
363 log.info(inspect.getsource(func))
364 else:
364 else:
365 schema.runchange(ver, change, changeset.step)
365 schema.runchange(ver, change, changeset.step)
366 log.info('done')
366 log.info('done')
367
367
368
368
369 def _migrate_version(schema, version, upgrade, err):
369 def _migrate_version(schema, version, upgrade, err):
370 if version is None:
370 if version is None:
371 return version
371 return version
372 # Version is specified: ensure we're upgrading in the right direction
372 # Version is specified: ensure we're upgrading in the right direction
373 # (current version < target version for upgrading; reverse for down)
373 # (current version < target version for upgrading; reverse for down)
374 version = VerNum(version)
374 version = VerNum(version)
375 cur = schema.version
375 cur = schema.version
376 if upgrade is not None:
376 if upgrade is not None:
377 if upgrade:
377 if upgrade:
378 direction = cur <= version
378 direction = cur <= version
379 else:
379 else:
380 direction = cur >= version
380 direction = cur >= version
381 if not direction:
381 if not direction:
382 raise exceptions.KnownError(err % (cur, version))
382 raise exceptions.KnownError(err % (cur, version))
383 return version
383 return version
@@ -1,27 +1,27 b''
1 """
1 """
2 Configuration parser module.
2 Configuration parser module.
3 """
3 """
4
4
5 from ConfigParser import ConfigParser
5 from ConfigParser import ConfigParser
6
6
7 from migrate.versioning.config import *
7 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
8 from migrate.versioning import pathed
8 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
9
9
10
10
11 class Parser(ConfigParser):
11 class Parser(ConfigParser):
12 """A project configuration file."""
12 """A project configuration file."""
13
13
14 def to_dict(self, sections=None):
14 def to_dict(self, sections=None):
15 """It's easier to access config values like dictionaries"""
15 """It's easier to access config values like dictionaries"""
16 return self._sections
16 return self._sections
17
17
18
18
19 class Config(pathed.Pathed, Parser):
19 class Config(pathed.Pathed, Parser):
20 """Configuration class."""
20 """Configuration class."""
21
21
22 def __init__(self, path, *p, **k):
22 def __init__(self, path, *p, **k):
23 """Confirm the config file exists; read it."""
23 """Confirm the config file exists; read it."""
24 self.require_found(path)
24 self.require_found(path)
25 pathed.Pathed.__init__(self, path)
25 pathed.Pathed.__init__(self, path)
26 Parser.__init__(self, *p, **k)
26 Parser.__init__(self, *p, **k)
27 self.read(path)
27 self.read(path)
@@ -1,254 +1,253 b''
1 """
1 """
2 Code to generate a Python model from a database or differences
2 Code to generate a Python model from a database or differences
3 between a model and database.
3 between a model and database.
4
4
5 Some of this is borrowed heavily from the AutoCode project at:
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
6 http://code.google.com/p/sqlautocode/
7 """
7 """
8
8
9 import sys
9 import sys
10 import logging
10 import logging
11
11
12 import sqlalchemy
12 import sqlalchemy
13
13
14 import migrate
14 from rhodecode.lib.dbmigrate import migrate
15 import migrate.changeset
15 from rhodecode.lib.dbmigrate.migrate import changeset
16
17
16
18 log = logging.getLogger(__name__)
17 log = logging.getLogger(__name__)
19 HEADER = """
18 HEADER = """
20 ## File autogenerated by genmodel.py
19 ## File autogenerated by genmodel.py
21
20
22 from sqlalchemy import *
21 from sqlalchemy import *
23 meta = MetaData()
22 meta = MetaData()
24 """
23 """
25
24
26 DECLARATIVE_HEADER = """
25 DECLARATIVE_HEADER = """
27 ## File autogenerated by genmodel.py
26 ## File autogenerated by genmodel.py
28
27
29 from sqlalchemy import *
28 from sqlalchemy import *
30 from sqlalchemy.ext import declarative
29 from sqlalchemy.ext import declarative
31
30
32 Base = declarative.declarative_base()
31 Base = declarative.declarative_base()
33 """
32 """
34
33
35
34
36 class ModelGenerator(object):
35 class ModelGenerator(object):
37
36
38 def __init__(self, diff, engine, declarative=False):
37 def __init__(self, diff, engine, declarative=False):
39 self.diff = diff
38 self.diff = diff
40 self.engine = engine
39 self.engine = engine
41 self.declarative = declarative
40 self.declarative = declarative
42
41
43 def column_repr(self, col):
42 def column_repr(self, col):
44 kwarg = []
43 kwarg = []
45 if col.key != col.name:
44 if col.key != col.name:
46 kwarg.append('key')
45 kwarg.append('key')
47 if col.primary_key:
46 if col.primary_key:
48 col.primary_key = True # otherwise it dumps it as 1
47 col.primary_key = True # otherwise it dumps it as 1
49 kwarg.append('primary_key')
48 kwarg.append('primary_key')
50 if not col.nullable:
49 if not col.nullable:
51 kwarg.append('nullable')
50 kwarg.append('nullable')
52 if col.onupdate:
51 if col.onupdate:
53 kwarg.append('onupdate')
52 kwarg.append('onupdate')
54 if col.default:
53 if col.default:
55 if col.primary_key:
54 if col.primary_key:
56 # I found that PostgreSQL automatically creates a
55 # I found that PostgreSQL automatically creates a
57 # default value for the sequence, but let's not show
56 # default value for the sequence, but let's not show
58 # that.
57 # that.
59 pass
58 pass
60 else:
59 else:
61 kwarg.append('default')
60 kwarg.append('default')
62 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
61 ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg)
63
62
64 # crs: not sure if this is good idea, but it gets rid of extra
63 # crs: not sure if this is good idea, but it gets rid of extra
65 # u''
64 # u''
66 name = col.name.encode('utf8')
65 name = col.name.encode('utf8')
67
66
68 type_ = col.type
67 type_ = col.type
69 for cls in col.type.__class__.__mro__:
68 for cls in col.type.__class__.__mro__:
70 if cls.__module__ == 'sqlalchemy.types' and \
69 if cls.__module__ == 'sqlalchemy.types' and \
71 not cls.__name__.isupper():
70 not cls.__name__.isupper():
72 if cls is not type_.__class__:
71 if cls is not type_.__class__:
73 type_ = cls()
72 type_ = cls()
74 break
73 break
75
74
76 data = {
75 data = {
77 'name': name,
76 'name': name,
78 'type': type_,
77 'type': type_,
79 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
78 'constraints': ', '.join([repr(cn) for cn in col.constraints]),
80 'args': ks and ks or ''}
79 'args': ks and ks or ''}
81
80
82 if data['constraints']:
81 if data['constraints']:
83 if data['args']:
82 if data['args']:
84 data['args'] = ',' + data['args']
83 data['args'] = ',' + data['args']
85
84
86 if data['constraints'] or data['args']:
85 if data['constraints'] or data['args']:
87 data['maybeComma'] = ','
86 data['maybeComma'] = ','
88 else:
87 else:
89 data['maybeComma'] = ''
88 data['maybeComma'] = ''
90
89
91 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
90 commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data
92 commonStuff = commonStuff.strip()
91 commonStuff = commonStuff.strip()
93 data['commonStuff'] = commonStuff
92 data['commonStuff'] = commonStuff
94 if self.declarative:
93 if self.declarative:
95 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
94 return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
96 else:
95 else:
97 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
96 return """Column(%(name)r, %(type)r%(commonStuff)s""" % data
98
97
99 def getTableDefn(self, table):
98 def getTableDefn(self, table):
100 out = []
99 out = []
101 tableName = table.name
100 tableName = table.name
102 if self.declarative:
101 if self.declarative:
103 out.append("class %(table)s(Base):" % {'table': tableName})
102 out.append("class %(table)s(Base):" % {'table': tableName})
104 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
103 out.append(" __tablename__ = '%(table)s'" % {'table': tableName})
105 for col in table.columns:
104 for col in table.columns:
106 out.append(" %s" % self.column_repr(col))
105 out.append(" %s" % self.column_repr(col))
107 else:
106 else:
108 out.append("%(table)s = Table('%(table)s', meta," % \
107 out.append("%(table)s = Table('%(table)s', meta," % \
109 {'table': tableName})
108 {'table': tableName})
110 for col in table.columns:
109 for col in table.columns:
111 out.append(" %s," % self.column_repr(col))
110 out.append(" %s," % self.column_repr(col))
112 out.append(")")
111 out.append(")")
113 return out
112 return out
114
113
115 def _get_tables(self,missingA=False,missingB=False,modified=False):
114 def _get_tables(self, missingA=False, missingB=False, modified=False):
116 to_process = []
115 to_process = []
117 for bool_,names,metadata in (
116 for bool_, names, metadata in (
118 (missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
117 (missingA, self.diff.tables_missing_from_A, self.diff.metadataB),
119 (missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
118 (missingB, self.diff.tables_missing_from_B, self.diff.metadataA),
120 (modified,self.diff.tables_different,self.diff.metadataA),
119 (modified, self.diff.tables_different, self.diff.metadataA),
121 ):
120 ):
122 if bool_:
121 if bool_:
123 for name in names:
122 for name in names:
124 yield metadata.tables.get(name)
123 yield metadata.tables.get(name)
125
124
126 def toPython(self):
125 def toPython(self):
127 """Assume database is current and model is empty."""
126 """Assume database is current and model is empty."""
128 out = []
127 out = []
129 if self.declarative:
128 if self.declarative:
130 out.append(DECLARATIVE_HEADER)
129 out.append(DECLARATIVE_HEADER)
131 else:
130 else:
132 out.append(HEADER)
131 out.append(HEADER)
133 out.append("")
132 out.append("")
134 for table in self._get_tables(missingA=True):
133 for table in self._get_tables(missingA=True):
135 out.extend(self.getTableDefn(table))
134 out.extend(self.getTableDefn(table))
136 out.append("")
135 out.append("")
137 return '\n'.join(out)
136 return '\n'.join(out)
138
137
139 def toUpgradeDowngradePython(self, indent=' '):
138 def toUpgradeDowngradePython(self, indent=' '):
140 ''' Assume model is most current and database is out-of-date. '''
139 ''' Assume model is most current and database is out-of-date. '''
141 decls = ['from migrate.changeset import schema',
140 decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema',
142 'meta = MetaData()']
141 'meta = MetaData()']
143 for table in self._get_tables(
142 for table in self._get_tables(
144 missingA=True,missingB=True,modified=True
143 missingA=True, missingB=True, modified=True
145 ):
144 ):
146 decls.extend(self.getTableDefn(table))
145 decls.extend(self.getTableDefn(table))
147
146
148 upgradeCommands, downgradeCommands = [], []
147 upgradeCommands, downgradeCommands = [], []
149 for tableName in self.diff.tables_missing_from_A:
148 for tableName in self.diff.tables_missing_from_A:
150 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
149 upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
151 downgradeCommands.append("%(table)s.create()" % \
150 downgradeCommands.append("%(table)s.create()" % \
152 {'table': tableName})
151 {'table': tableName})
153 for tableName in self.diff.tables_missing_from_B:
152 for tableName in self.diff.tables_missing_from_B:
154 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
153 upgradeCommands.append("%(table)s.create()" % {'table': tableName})
155 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
154 downgradeCommands.append("%(table)s.drop()" % {'table': tableName})
156
155
157 for tableName in self.diff.tables_different:
156 for tableName in self.diff.tables_different:
158 dbTable = self.diff.metadataB.tables[tableName]
157 dbTable = self.diff.metadataB.tables[tableName]
159 missingInDatabase, missingInModel, diffDecl = \
158 missingInDatabase, missingInModel, diffDecl = \
160 self.diff.colDiffs[tableName]
159 self.diff.colDiffs[tableName]
161 for col in missingInDatabase:
160 for col in missingInDatabase:
162 upgradeCommands.append('%s.columns[%r].create()' % (
161 upgradeCommands.append('%s.columns[%r].create()' % (
163 modelTable, col.name))
162 modelTable, col.name))
164 downgradeCommands.append('%s.columns[%r].drop()' % (
163 downgradeCommands.append('%s.columns[%r].drop()' % (
165 modelTable, col.name))
164 modelTable, col.name))
166 for col in missingInModel:
165 for col in missingInModel:
167 upgradeCommands.append('%s.columns[%r].drop()' % (
166 upgradeCommands.append('%s.columns[%r].drop()' % (
168 modelTable, col.name))
167 modelTable, col.name))
169 downgradeCommands.append('%s.columns[%r].create()' % (
168 downgradeCommands.append('%s.columns[%r].create()' % (
170 modelTable, col.name))
169 modelTable, col.name))
171 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
170 for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
172 upgradeCommands.append(
171 upgradeCommands.append(
173 'assert False, "Can\'t alter columns: %s:%s=>%s"',
172 'assert False, "Can\'t alter columns: %s:%s=>%s"',
174 modelTable, modelCol.name, databaseCol.name)
173 modelTable, modelCol.name, databaseCol.name)
175 downgradeCommands.append(
174 downgradeCommands.append(
176 'assert False, "Can\'t alter columns: %s:%s=>%s"',
175 'assert False, "Can\'t alter columns: %s:%s=>%s"',
177 modelTable, modelCol.name, databaseCol.name)
176 modelTable, modelCol.name, databaseCol.name)
178 pre_command = ' meta.bind = migrate_engine'
177 pre_command = ' meta.bind = migrate_engine'
179
178
180 return (
179 return (
181 '\n'.join(decls),
180 '\n'.join(decls),
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
181 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]),
183 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
182 '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands]))
184
183
185 def _db_can_handle_this_change(self,td):
184 def _db_can_handle_this_change(self, td):
186 if (td.columns_missing_from_B
185 if (td.columns_missing_from_B
187 and not td.columns_missing_from_A
186 and not td.columns_missing_from_A
188 and not td.columns_different):
187 and not td.columns_different):
189 # Even sqlite can handle this.
188 # Even sqlite can handle this.
190 return True
189 return True
191 else:
190 else:
192 return not self.engine.url.drivername.startswith('sqlite')
191 return not self.engine.url.drivername.startswith('sqlite')
193
192
194 def applyModel(self):
193 def applyModel(self):
195 """Apply model to current database."""
194 """Apply model to current database."""
196
195
197 meta = sqlalchemy.MetaData(self.engine)
196 meta = sqlalchemy.MetaData(self.engine)
198
197
199 for table in self._get_tables(missingA=True):
198 for table in self._get_tables(missingA=True):
200 table = table.tometadata(meta)
199 table = table.tometadata(meta)
201 table.drop()
200 table.drop()
202 for table in self._get_tables(missingB=True):
201 for table in self._get_tables(missingB=True):
203 table = table.tometadata(meta)
202 table = table.tometadata(meta)
204 table.create()
203 table.create()
205 for modelTable in self._get_tables(modified=True):
204 for modelTable in self._get_tables(modified=True):
206 tableName = modelTable.name
205 tableName = modelTable.name
207 modelTable = modelTable.tometadata(meta)
206 modelTable = modelTable.tometadata(meta)
208 dbTable = self.diff.metadataB.tables[tableName]
207 dbTable = self.diff.metadataB.tables[tableName]
209
208
210 td = self.diff.tables_different[tableName]
209 td = self.diff.tables_different[tableName]
211
210
212 if self._db_can_handle_this_change(td):
211 if self._db_can_handle_this_change(td):
213
212
214 for col in td.columns_missing_from_B:
213 for col in td.columns_missing_from_B:
215 modelTable.columns[col].create()
214 modelTable.columns[col].create()
216 for col in td.columns_missing_from_A:
215 for col in td.columns_missing_from_A:
217 dbTable.columns[col].drop()
216 dbTable.columns[col].drop()
218 # XXX handle column changes here.
217 # XXX handle column changes here.
219 else:
218 else:
220 # Sqlite doesn't support drop column, so you have to
219 # Sqlite doesn't support drop column, so you have to
221 # do more: create temp table, copy data to it, drop
220 # do more: create temp table, copy data to it, drop
222 # old table, create new table, copy data back.
221 # old table, create new table, copy data back.
223 #
222 #
224 # I wonder if this is guaranteed to be unique?
223 # I wonder if this is guaranteed to be unique?
225 tempName = '_temp_%s' % modelTable.name
224 tempName = '_temp_%s' % modelTable.name
226
225
227 def getCopyStatement():
226 def getCopyStatement():
228 preparer = self.engine.dialect.preparer
227 preparer = self.engine.dialect.preparer
229 commonCols = []
228 commonCols = []
230 for modelCol in modelTable.columns:
229 for modelCol in modelTable.columns:
231 if modelCol.name in dbTable.columns:
230 if modelCol.name in dbTable.columns:
232 commonCols.append(modelCol.name)
231 commonCols.append(modelCol.name)
233 commonColsStr = ', '.join(commonCols)
232 commonColsStr = ', '.join(commonCols)
234 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
233 return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
235 (tableName, commonColsStr, commonColsStr, tempName)
234 (tableName, commonColsStr, commonColsStr, tempName)
236
235
237 # Move the data in one transaction, so that we don't
236 # Move the data in one transaction, so that we don't
238 # leave the database in a nasty state.
237 # leave the database in a nasty state.
239 connection = self.engine.connect()
238 connection = self.engine.connect()
240 trans = connection.begin()
239 trans = connection.begin()
241 try:
240 try:
242 connection.execute(
241 connection.execute(
243 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
242 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
244 (tempName, modelTable.name))
243 (tempName, modelTable.name))
245 # make sure the drop takes place inside our
244 # make sure the drop takes place inside our
246 # transaction with the bind parameter
245 # transaction with the bind parameter
247 modelTable.drop(bind=connection)
246 modelTable.drop(bind=connection)
248 modelTable.create(bind=connection)
247 modelTable.create(bind=connection)
249 connection.execute(getCopyStatement())
248 connection.execute(getCopyStatement())
250 connection.execute('DROP TABLE %s' % tempName)
249 connection.execute('DROP TABLE %s' % tempName)
251 trans.commit()
250 trans.commit()
252 except:
251 except:
253 trans.rollback()
252 trans.rollback()
254 raise
253 raise
@@ -1,75 +1,75 b''
1 """
1 """
2 A path/directory class.
2 A path/directory class.
3 """
3 """
4
4
5 import os
5 import os
6 import shutil
6 import shutil
7 import logging
7 import logging
8
8
9 from migrate import exceptions
9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 from migrate.versioning.config import *
10 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
11 from migrate.versioning.util import KeyedInstance
11 from rhodecode.lib.dbmigrate.migrate.versioning.util import KeyedInstance
12
12
13
13
14 log = logging.getLogger(__name__)
14 log = logging.getLogger(__name__)
15
15
16 class Pathed(KeyedInstance):
16 class Pathed(KeyedInstance):
17 """
17 """
18 A class associated with a path/directory tree.
18 A class associated with a path/directory tree.
19
19
20 Only one instance of this class may exist for a particular file;
20 Only one instance of this class may exist for a particular file;
21 __new__ will return an existing instance if possible
21 __new__ will return an existing instance if possible
22 """
22 """
23 parent = None
23 parent = None
24
24
25 @classmethod
25 @classmethod
26 def _key(cls, path):
26 def _key(cls, path):
27 return str(path)
27 return str(path)
28
28
29 def __init__(self, path):
29 def __init__(self, path):
30 self.path = path
30 self.path = path
31 if self.__class__.parent is not None:
31 if self.__class__.parent is not None:
32 self._init_parent(path)
32 self._init_parent(path)
33
33
34 def _init_parent(self, path):
34 def _init_parent(self, path):
35 """Try to initialize this object's parent, if it has one"""
35 """Try to initialize this object's parent, if it has one"""
36 parent_path = self.__class__._parent_path(path)
36 parent_path = self.__class__._parent_path(path)
37 self.parent = self.__class__.parent(parent_path)
37 self.parent = self.__class__.parent(parent_path)
38 log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
38 log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
39 self.parent._init_child(path, self)
39 self.parent._init_child(path, self)
40
40
41 def _init_child(self, child, path):
41 def _init_child(self, child, path):
42 """Run when a child of this object is initialized.
42 """Run when a child of this object is initialized.
43
43
44 Parameters: the child object; the path to this object (its
44 Parameters: the child object; the path to this object (its
45 parent)
45 parent)
46 """
46 """
47
47
48 @classmethod
48 @classmethod
49 def _parent_path(cls, path):
49 def _parent_path(cls, path):
50 """
50 """
51 Fetch the path of this object's parent from this object's path.
51 Fetch the path of this object's parent from this object's path.
52 """
52 """
53 # os.path.dirname(), but strip directories like files (like
53 # os.path.dirname(), but strip directories like files (like
54 # unix basename)
54 # unix basename)
55 #
55 #
56 # Treat directories like files...
56 # Treat directories like files...
57 if path[-1] == '/':
57 if path[-1] == '/':
58 path = path[:-1]
58 path = path[:-1]
59 ret = os.path.dirname(path)
59 ret = os.path.dirname(path)
60 return ret
60 return ret
61
61
62 @classmethod
62 @classmethod
63 def require_notfound(cls, path):
63 def require_notfound(cls, path):
64 """Ensures a given path does not already exist"""
64 """Ensures a given path does not already exist"""
65 if os.path.exists(path):
65 if os.path.exists(path):
66 raise exceptions.PathFoundError(path)
66 raise exceptions.PathFoundError(path)
67
67
68 @classmethod
68 @classmethod
69 def require_found(cls, path):
69 def require_found(cls, path):
70 """Ensures a given path already exists"""
70 """Ensures a given path already exists"""
71 if not os.path.exists(path):
71 if not os.path.exists(path):
72 raise exceptions.PathNotFoundError(path)
72 raise exceptions.PathNotFoundError(path)
73
73
74 def __str__(self):
74 def __str__(self):
75 return self.path
75 return self.path
@@ -1,231 +1,231 b''
1 """
1 """
2 SQLAlchemy migrate repository management.
2 SQLAlchemy migrate repository management.
3 """
3 """
4 import os
4 import os
5 import shutil
5 import shutil
6 import string
6 import string
7 import logging
7 import logging
8
8
9 from pkg_resources import resource_filename
9 from pkg_resources import resource_filename
10 from tempita import Template as TempitaTemplate
10 from tempita import Template as TempitaTemplate
11
11
12 from migrate import exceptions
12 from rhodecode.lib.dbmigrate.migrate import exceptions
13 from migrate.versioning import version, pathed, cfgparse
13 from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse
14 from migrate.versioning.template import Template
14 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
15 from migrate.versioning.config import *
15 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
16
16
17
17
18 log = logging.getLogger(__name__)
18 log = logging.getLogger(__name__)
19
19
20 class Changeset(dict):
20 class Changeset(dict):
21 """A collection of changes to be applied to a database.
21 """A collection of changes to be applied to a database.
22
22
23 Changesets are bound to a repository and manage a set of
23 Changesets are bound to a repository and manage a set of
24 scripts from that repository.
24 scripts from that repository.
25
25
26 Behaves like a dict, for the most part. Keys are ordered based on step value.
26 Behaves like a dict, for the most part. Keys are ordered based on step value.
27 """
27 """
28
28
29 def __init__(self, start, *changes, **k):
29 def __init__(self, start, *changes, **k):
30 """
30 """
31 Give a start version; step must be explicitly stated.
31 Give a start version; step must be explicitly stated.
32 """
32 """
33 self.step = k.pop('step', 1)
33 self.step = k.pop('step', 1)
34 self.start = version.VerNum(start)
34 self.start = version.VerNum(start)
35 self.end = self.start
35 self.end = self.start
36 for change in changes:
36 for change in changes:
37 self.add(change)
37 self.add(change)
38
38
39 def __iter__(self):
39 def __iter__(self):
40 return iter(self.items())
40 return iter(self.items())
41
41
42 def keys(self):
42 def keys(self):
43 """
43 """
44 In a series of upgrades x -> y, keys are version x. Sorted.
44 In a series of upgrades x -> y, keys are version x. Sorted.
45 """
45 """
46 ret = super(Changeset, self).keys()
46 ret = super(Changeset, self).keys()
47 # Reverse order if downgrading
47 # Reverse order if downgrading
48 ret.sort(reverse=(self.step < 1))
48 ret.sort(reverse=(self.step < 1))
49 return ret
49 return ret
50
50
51 def values(self):
51 def values(self):
52 return [self[k] for k in self.keys()]
52 return [self[k] for k in self.keys()]
53
53
54 def items(self):
54 def items(self):
55 return zip(self.keys(), self.values())
55 return zip(self.keys(), self.values())
56
56
57 def add(self, change):
57 def add(self, change):
58 """Add new change to changeset"""
58 """Add new change to changeset"""
59 key = self.end
59 key = self.end
60 self.end += self.step
60 self.end += self.step
61 self[key] = change
61 self[key] = change
62
62
63 def run(self, *p, **k):
63 def run(self, *p, **k):
64 """Run the changeset scripts"""
64 """Run the changeset scripts"""
65 for version, script in self:
65 for version, script in self:
66 script.run(*p, **k)
66 script.run(*p, **k)
67
67
68
68
69 class Repository(pathed.Pathed):
69 class Repository(pathed.Pathed):
70 """A project's change script repository"""
70 """A project's change script repository"""
71
71
72 _config = 'migrate.cfg'
72 _config = 'migrate.cfg'
73 _versions = 'versions'
73 _versions = 'versions'
74
74
75 def __init__(self, path):
75 def __init__(self, path):
76 log.debug('Loading repository %s...' % path)
76 log.debug('Loading repository %s...' % path)
77 self.verify(path)
77 self.verify(path)
78 super(Repository, self).__init__(path)
78 super(Repository, self).__init__(path)
79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
79 self.config = cfgparse.Config(os.path.join(self.path, self._config))
80 self.versions = version.Collection(os.path.join(self.path,
80 self.versions = version.Collection(os.path.join(self.path,
81 self._versions))
81 self._versions))
82 log.debug('Repository %s loaded successfully' % path)
82 log.debug('Repository %s loaded successfully' % path)
83 log.debug('Config: %r' % self.config.to_dict())
83 log.debug('Config: %r' % self.config.to_dict())
84
84
85 @classmethod
85 @classmethod
86 def verify(cls, path):
86 def verify(cls, path):
87 """
87 """
88 Ensure the target path is a valid repository.
88 Ensure the target path is a valid repository.
89
89
90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
90 :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
91 """
91 """
92 # Ensure the existence of required files
92 # Ensure the existence of required files
93 try:
93 try:
94 cls.require_found(path)
94 cls.require_found(path)
95 cls.require_found(os.path.join(path, cls._config))
95 cls.require_found(os.path.join(path, cls._config))
96 cls.require_found(os.path.join(path, cls._versions))
96 cls.require_found(os.path.join(path, cls._versions))
97 except exceptions.PathNotFoundError, e:
97 except exceptions.PathNotFoundError, e:
98 raise exceptions.InvalidRepositoryError(path)
98 raise exceptions.InvalidRepositoryError(path)
99
99
100 @classmethod
100 @classmethod
101 def prepare_config(cls, tmpl_dir, name, options=None):
101 def prepare_config(cls, tmpl_dir, name, options=None):
102 """
102 """
103 Prepare a project configuration file for a new project.
103 Prepare a project configuration file for a new project.
104
104
105 :param tmpl_dir: Path to Repository template
105 :param tmpl_dir: Path to Repository template
106 :param config_file: Name of the config file in Repository template
106 :param config_file: Name of the config file in Repository template
107 :param name: Repository name
107 :param name: Repository name
108 :type tmpl_dir: string
108 :type tmpl_dir: string
109 :type config_file: string
109 :type config_file: string
110 :type name: string
110 :type name: string
111 :returns: Populated config file
111 :returns: Populated config file
112 """
112 """
113 if options is None:
113 if options is None:
114 options = {}
114 options = {}
115 options.setdefault('version_table', 'migrate_version')
115 options.setdefault('version_table', 'migrate_version')
116 options.setdefault('repository_id', name)
116 options.setdefault('repository_id', name)
117 options.setdefault('required_dbs', [])
117 options.setdefault('required_dbs', [])
118
118
119 tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
119 tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
120 ret = TempitaTemplate(tmpl).substitute(options)
120 ret = TempitaTemplate(tmpl).substitute(options)
121
121
122 # cleanup
122 # cleanup
123 del options['__template_name__']
123 del options['__template_name__']
124
124
125 return ret
125 return ret
126
126
127 @classmethod
127 @classmethod
128 def create(cls, path, name, **opts):
128 def create(cls, path, name, **opts):
129 """Create a repository at a specified path"""
129 """Create a repository at a specified path"""
130 cls.require_notfound(path)
130 cls.require_notfound(path)
131 theme = opts.pop('templates_theme', None)
131 theme = opts.pop('templates_theme', None)
132 t_path = opts.pop('templates_path', None)
132 t_path = opts.pop('templates_path', None)
133
133
134 # Create repository
134 # Create repository
135 tmpl_dir = Template(t_path).get_repository(theme=theme)
135 tmpl_dir = Template(t_path).get_repository(theme=theme)
136 shutil.copytree(tmpl_dir, path)
136 shutil.copytree(tmpl_dir, path)
137
137
138 # Edit config defaults
138 # Edit config defaults
139 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
139 config_text = cls.prepare_config(tmpl_dir, name, options=opts)
140 fd = open(os.path.join(path, cls._config), 'w')
140 fd = open(os.path.join(path, cls._config), 'w')
141 fd.write(config_text)
141 fd.write(config_text)
142 fd.close()
142 fd.close()
143
143
144 opts['repository_name'] = name
144 opts['repository_name'] = name
145
145
146 # Create a management script
146 # Create a management script
147 manager = os.path.join(path, 'manage.py')
147 manager = os.path.join(path, 'manage.py')
148 Repository.create_manage_file(manager, templates_theme=theme,
148 Repository.create_manage_file(manager, templates_theme=theme,
149 templates_path=t_path, **opts)
149 templates_path=t_path, **opts)
150
150
151 return cls(path)
151 return cls(path)
152
152
153 def create_script(self, description, **k):
153 def create_script(self, description, **k):
154 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
154 """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
155 self.versions.create_new_python_version(description, **k)
155 self.versions.create_new_python_version(description, **k)
156
156
157 def create_script_sql(self, database, **k):
157 def create_script_sql(self, database, **k):
158 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
158 """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
159 self.versions.create_new_sql_version(database, **k)
159 self.versions.create_new_sql_version(database, **k)
160
160
161 @property
161 @property
162 def latest(self):
162 def latest(self):
163 """API to :attr:`migrate.versioning.version.Collection.latest`"""
163 """API to :attr:`migrate.versioning.version.Collection.latest`"""
164 return self.versions.latest
164 return self.versions.latest
165
165
166 @property
166 @property
167 def version_table(self):
167 def version_table(self):
168 """Returns version_table name specified in config"""
168 """Returns version_table name specified in config"""
169 return self.config.get('db_settings', 'version_table')
169 return self.config.get('db_settings', 'version_table')
170
170
171 @property
171 @property
172 def id(self):
172 def id(self):
173 """Returns repository id specified in config"""
173 """Returns repository id specified in config"""
174 return self.config.get('db_settings', 'repository_id')
174 return self.config.get('db_settings', 'repository_id')
175
175
176 def version(self, *p, **k):
176 def version(self, *p, **k):
177 """API to :attr:`migrate.versioning.version.Collection.version`"""
177 """API to :attr:`migrate.versioning.version.Collection.version`"""
178 return self.versions.version(*p, **k)
178 return self.versions.version(*p, **k)
179
179
180 @classmethod
180 @classmethod
181 def clear(cls):
181 def clear(cls):
182 # TODO: deletes repo
182 # TODO: deletes repo
183 super(Repository, cls).clear()
183 super(Repository, cls).clear()
184 version.Collection.clear()
184 version.Collection.clear()
185
185
186 def changeset(self, database, start, end=None):
186 def changeset(self, database, start, end=None):
187 """Create a changeset to migrate this database from ver. start to end/latest.
187 """Create a changeset to migrate this database from ver. start to end/latest.
188
188
189 :param database: name of database to generate changeset
189 :param database: name of database to generate changeset
190 :param start: version to start at
190 :param start: version to start at
191 :param end: version to end at (latest if None given)
191 :param end: version to end at (latest if None given)
192 :type database: string
192 :type database: string
193 :type start: int
193 :type start: int
194 :type end: int
194 :type end: int
195 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
195 :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
196 """
196 """
197 start = version.VerNum(start)
197 start = version.VerNum(start)
198
198
199 if end is None:
199 if end is None:
200 end = self.latest
200 end = self.latest
201 else:
201 else:
202 end = version.VerNum(end)
202 end = version.VerNum(end)
203
203
204 if start <= end:
204 if start <= end:
205 step = 1
205 step = 1
206 range_mod = 1
206 range_mod = 1
207 op = 'upgrade'
207 op = 'upgrade'
208 else:
208 else:
209 step = -1
209 step = -1
210 range_mod = 0
210 range_mod = 0
211 op = 'downgrade'
211 op = 'downgrade'
212
212
213 versions = range(start + range_mod, end + range_mod, step)
213 versions = range(start + range_mod, end + range_mod, step)
214 changes = [self.version(v).script(database, op) for v in versions]
214 changes = [self.version(v).script(database, op) for v in versions]
215 ret = Changeset(start, step=step, *changes)
215 ret = Changeset(start, step=step, *changes)
216 return ret
216 return ret
217
217
218 @classmethod
218 @classmethod
219 def create_manage_file(cls, file_, **opts):
219 def create_manage_file(cls, file_, **opts):
220 """Create a project management script (manage.py)
220 """Create a project management script (manage.py)
221
221
222 :param file_: Destination file to be written
222 :param file_: Destination file to be written
223 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
223 :param opts: Options that are passed to :func:`migrate.versioning.shell.main`
224 """
224 """
225 mng_file = Template(opts.pop('templates_path', None))\
225 mng_file = Template(opts.pop('templates_path', None))\
226 .get_manage(theme=opts.pop('templates_theme', None))
226 .get_manage(theme=opts.pop('templates_theme', None))
227
227
228 tmpl = open(mng_file).read()
228 tmpl = open(mng_file).read()
229 fd = open(file_, 'w')
229 fd = open(file_, 'w')
230 fd.write(TempitaTemplate(tmpl).substitute(opts))
230 fd.write(TempitaTemplate(tmpl).substitute(opts))
231 fd.close()
231 fd.close()
@@ -1,213 +1,213 b''
1 """
1 """
2 Database schema version management.
2 Database schema version management.
3 """
3 """
4 import sys
4 import sys
5 import logging
5 import logging
6
6
7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
7 from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
8 create_engine)
8 create_engine)
9 from sqlalchemy.sql import and_
9 from sqlalchemy.sql import and_
10 from sqlalchemy import exceptions as sa_exceptions
10 from sqlalchemy import exceptions as sa_exceptions
11 from sqlalchemy.sql import bindparam
11 from sqlalchemy.sql import bindparam
12
12
13 from migrate import exceptions
13 from rhodecode.lib.dbmigrate.migrate import exceptions
14 from migrate.versioning import genmodel, schemadiff
14 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
15 from migrate.versioning.repository import Repository
15 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
16 from migrate.versioning.util import load_model
16 from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model
17 from migrate.versioning.version import VerNum
17 from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum
18
18
19
19
20 log = logging.getLogger(__name__)
20 log = logging.getLogger(__name__)
21
21
22 class ControlledSchema(object):
22 class ControlledSchema(object):
23 """A database under version control"""
23 """A database under version control"""
24
24
25 def __init__(self, engine, repository):
25 def __init__(self, engine, repository):
26 if isinstance(repository, basestring):
26 if isinstance(repository, basestring):
27 repository = Repository(repository)
27 repository = Repository(repository)
28 self.engine = engine
28 self.engine = engine
29 self.repository = repository
29 self.repository = repository
30 self.meta = MetaData(engine)
30 self.meta = MetaData(engine)
31 self.load()
31 self.load()
32
32
33 def __eq__(self, other):
33 def __eq__(self, other):
34 """Compare two schemas by repositories and versions"""
34 """Compare two schemas by repositories and versions"""
35 return (self.repository is other.repository \
35 return (self.repository is other.repository \
36 and self.version == other.version)
36 and self.version == other.version)
37
37
38 def load(self):
38 def load(self):
39 """Load controlled schema version info from DB"""
39 """Load controlled schema version info from DB"""
40 tname = self.repository.version_table
40 tname = self.repository.version_table
41 try:
41 try:
42 if not hasattr(self, 'table') or self.table is None:
42 if not hasattr(self, 'table') or self.table is None:
43 self.table = Table(tname, self.meta, autoload=True)
43 self.table = Table(tname, self.meta, autoload=True)
44
44
45 result = self.engine.execute(self.table.select(
45 result = self.engine.execute(self.table.select(
46 self.table.c.repository_id == str(self.repository.id)))
46 self.table.c.repository_id == str(self.repository.id)))
47
47
48 data = list(result)[0]
48 data = list(result)[0]
49 except:
49 except:
50 cls, exc, tb = sys.exc_info()
50 cls, exc, tb = sys.exc_info()
51 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
51 raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
52
52
53 self.version = data['version']
53 self.version = data['version']
54 return data
54 return data
55
55
56 def drop(self):
56 def drop(self):
57 """
57 """
58 Remove version control from a database.
58 Remove version control from a database.
59 """
59 """
60 try:
60 try:
61 self.table.drop()
61 self.table.drop()
62 except (sa_exceptions.SQLError):
62 except (sa_exceptions.SQLError):
63 raise exceptions.DatabaseNotControlledError(str(self.table))
63 raise exceptions.DatabaseNotControlledError(str(self.table))
64
64
65 def changeset(self, version=None):
65 def changeset(self, version=None):
66 """API to Changeset creation.
66 """API to Changeset creation.
67
67
68 Uses self.version for start version and engine.name
68 Uses self.version for start version and engine.name
69 to get database name.
69 to get database name.
70 """
70 """
71 database = self.engine.name
71 database = self.engine.name
72 start_ver = self.version
72 start_ver = self.version
73 changeset = self.repository.changeset(database, start_ver, version)
73 changeset = self.repository.changeset(database, start_ver, version)
74 return changeset
74 return changeset
75
75
76 def runchange(self, ver, change, step):
76 def runchange(self, ver, change, step):
77 startver = ver
77 startver = ver
78 endver = ver + step
78 endver = ver + step
79 # Current database version must be correct! Don't run if corrupt!
79 # Current database version must be correct! Don't run if corrupt!
80 if self.version != startver:
80 if self.version != startver:
81 raise exceptions.InvalidVersionError("%s is not %s" % \
81 raise exceptions.InvalidVersionError("%s is not %s" % \
82 (self.version, startver))
82 (self.version, startver))
83 # Run the change
83 # Run the change
84 change.run(self.engine, step)
84 change.run(self.engine, step)
85
85
86 # Update/refresh database version
86 # Update/refresh database version
87 self.update_repository_table(startver, endver)
87 self.update_repository_table(startver, endver)
88 self.load()
88 self.load()
89
89
90 def update_repository_table(self, startver, endver):
90 def update_repository_table(self, startver, endver):
91 """Update version_table with new information"""
91 """Update version_table with new information"""
92 update = self.table.update(and_(self.table.c.version == int(startver),
92 update = self.table.update(and_(self.table.c.version == int(startver),
93 self.table.c.repository_id == str(self.repository.id)))
93 self.table.c.repository_id == str(self.repository.id)))
94 self.engine.execute(update, version=int(endver))
94 self.engine.execute(update, version=int(endver))
95
95
96 def upgrade(self, version=None):
96 def upgrade(self, version=None):
97 """
97 """
98 Upgrade (or downgrade) to a specified version, or latest version.
98 Upgrade (or downgrade) to a specified version, or latest version.
99 """
99 """
100 changeset = self.changeset(version)
100 changeset = self.changeset(version)
101 for ver, change in changeset:
101 for ver, change in changeset:
102 self.runchange(ver, change, changeset.step)
102 self.runchange(ver, change, changeset.step)
103
103
104 def update_db_from_model(self, model):
104 def update_db_from_model(self, model):
105 """
105 """
106 Modify the database to match the structure of the current Python model.
106 Modify the database to match the structure of the current Python model.
107 """
107 """
108 model = load_model(model)
108 model = load_model(model)
109
109
110 diff = schemadiff.getDiffOfModelAgainstDatabase(
110 diff = schemadiff.getDiffOfModelAgainstDatabase(
111 model, self.engine, excludeTables=[self.repository.version_table]
111 model, self.engine, excludeTables=[self.repository.version_table]
112 )
112 )
113 genmodel.ModelGenerator(diff,self.engine).applyModel()
113 genmodel.ModelGenerator(diff,self.engine).applyModel()
114
114
115 self.update_repository_table(self.version, int(self.repository.latest))
115 self.update_repository_table(self.version, int(self.repository.latest))
116
116
117 self.load()
117 self.load()
118
118
119 @classmethod
119 @classmethod
120 def create(cls, engine, repository, version=None):
120 def create(cls, engine, repository, version=None):
121 """
121 """
122 Declare a database to be under a repository's version control.
122 Declare a database to be under a repository's version control.
123
123
124 :raises: :exc:`DatabaseAlreadyControlledError`
124 :raises: :exc:`DatabaseAlreadyControlledError`
125 :returns: :class:`ControlledSchema`
125 :returns: :class:`ControlledSchema`
126 """
126 """
127 # Confirm that the version # is valid: positive, integer,
127 # Confirm that the version # is valid: positive, integer,
128 # exists in repos
128 # exists in repos
129 if isinstance(repository, basestring):
129 if isinstance(repository, basestring):
130 repository = Repository(repository)
130 repository = Repository(repository)
131 version = cls._validate_version(repository, version)
131 version = cls._validate_version(repository, version)
132 table = cls._create_table_version(engine, repository, version)
132 table = cls._create_table_version(engine, repository, version)
133 # TODO: history table
133 # TODO: history table
134 # Load repository information and return
134 # Load repository information and return
135 return cls(engine, repository)
135 return cls(engine, repository)
136
136
137 @classmethod
137 @classmethod
138 def _validate_version(cls, repository, version):
138 def _validate_version(cls, repository, version):
139 """
139 """
140 Ensures this is a valid version number for this repository.
140 Ensures this is a valid version number for this repository.
141
141
142 :raises: :exc:`InvalidVersionError` if invalid
142 :raises: :exc:`InvalidVersionError` if invalid
143 :return: valid version number
143 :return: valid version number
144 """
144 """
145 if version is None:
145 if version is None:
146 version = 0
146 version = 0
147 try:
147 try:
148 version = VerNum(version) # raises valueerror
148 version = VerNum(version) # raises valueerror
149 if version < 0 or version > repository.latest:
149 if version < 0 or version > repository.latest:
150 raise ValueError()
150 raise ValueError()
151 except ValueError:
151 except ValueError:
152 raise exceptions.InvalidVersionError(version)
152 raise exceptions.InvalidVersionError(version)
153 return version
153 return version
154
154
155 @classmethod
155 @classmethod
156 def _create_table_version(cls, engine, repository, version):
156 def _create_table_version(cls, engine, repository, version):
157 """
157 """
158 Creates the versioning table in a database.
158 Creates the versioning table in a database.
159
159
160 :raises: :exc:`DatabaseAlreadyControlledError`
160 :raises: :exc:`DatabaseAlreadyControlledError`
161 """
161 """
162 # Create tables
162 # Create tables
163 tname = repository.version_table
163 tname = repository.version_table
164 meta = MetaData(engine)
164 meta = MetaData(engine)
165
165
166 table = Table(
166 table = Table(
167 tname, meta,
167 tname, meta,
168 Column('repository_id', String(250), primary_key=True),
168 Column('repository_id', String(250), primary_key=True),
169 Column('repository_path', Text),
169 Column('repository_path', Text),
170 Column('version', Integer), )
170 Column('version', Integer), )
171
171
172 # there can be multiple repositories/schemas in the same db
172 # there can be multiple repositories/schemas in the same db
173 if not table.exists():
173 if not table.exists():
174 table.create()
174 table.create()
175
175
176 # test for existing repository_id
176 # test for existing repository_id
177 s = table.select(table.c.repository_id == bindparam("repository_id"))
177 s = table.select(table.c.repository_id == bindparam("repository_id"))
178 result = engine.execute(s, repository_id=repository.id)
178 result = engine.execute(s, repository_id=repository.id)
179 if result.fetchone():
179 if result.fetchone():
180 raise exceptions.DatabaseAlreadyControlledError
180 raise exceptions.DatabaseAlreadyControlledError
181
181
182 # Insert data
182 # Insert data
183 engine.execute(table.insert().values(
183 engine.execute(table.insert().values(
184 repository_id=repository.id,
184 repository_id=repository.id,
185 repository_path=repository.path,
185 repository_path=repository.path,
186 version=int(version)))
186 version=int(version)))
187 return table
187 return table
188
188
189 @classmethod
189 @classmethod
190 def compare_model_to_db(cls, engine, model, repository):
190 def compare_model_to_db(cls, engine, model, repository):
191 """
191 """
192 Compare the current model against the current database.
192 Compare the current model against the current database.
193 """
193 """
194 if isinstance(repository, basestring):
194 if isinstance(repository, basestring):
195 repository = Repository(repository)
195 repository = Repository(repository)
196 model = load_model(model)
196 model = load_model(model)
197
197
198 diff = schemadiff.getDiffOfModelAgainstDatabase(
198 diff = schemadiff.getDiffOfModelAgainstDatabase(
199 model, engine, excludeTables=[repository.version_table])
199 model, engine, excludeTables=[repository.version_table])
200 return diff
200 return diff
201
201
202 @classmethod
202 @classmethod
203 def create_model(cls, engine, repository, declarative=False):
203 def create_model(cls, engine, repository, declarative=False):
204 """
204 """
205 Dump the current database as a Python model.
205 Dump the current database as a Python model.
206 """
206 """
207 if isinstance(repository, basestring):
207 if isinstance(repository, basestring):
208 repository = Repository(repository)
208 repository = Repository(repository)
209
209
210 diff = schemadiff.getDiffOfModelAgainstDatabase(
210 diff = schemadiff.getDiffOfModelAgainstDatabase(
211 MetaData(), engine, excludeTables=[repository.version_table]
211 MetaData(), engine, excludeTables=[repository.version_table]
212 )
212 )
213 return genmodel.ModelGenerator(diff, engine, declarative).toPython()
213 return genmodel.ModelGenerator(diff, engine, declarative).toPython()
@@ -1,285 +1,285 b''
1 """
1 """
2 Schema differencing support.
2 Schema differencing support.
3 """
3 """
4
4
5 import logging
5 import logging
6 import sqlalchemy
6 import sqlalchemy
7
7
8 from migrate.changeset import SQLA_06
8 from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06
9 from sqlalchemy.types import Float
9 from sqlalchemy.types import Float
10
10
11 log = logging.getLogger(__name__)
11 log = logging.getLogger(__name__)
12
12
13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
13 def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
14 """
14 """
15 Return differences of model against database.
15 Return differences of model against database.
16
16
17 :return: object which will evaluate to :keyword:`True` if there \
17 :return: object which will evaluate to :keyword:`True` if there \
18 are differences else :keyword:`False`.
18 are differences else :keyword:`False`.
19 """
19 """
20 return SchemaDiff(metadata,
20 return SchemaDiff(metadata,
21 sqlalchemy.MetaData(engine, reflect=True),
21 sqlalchemy.MetaData(engine, reflect=True),
22 labelA='model',
22 labelA='model',
23 labelB='database',
23 labelB='database',
24 excludeTables=excludeTables)
24 excludeTables=excludeTables)
25
25
26
26
27 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
27 def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
28 """
28 """
29 Return differences of model against another model.
29 Return differences of model against another model.
30
30
31 :return: object which will evaluate to :keyword:`True` if there \
31 :return: object which will evaluate to :keyword:`True` if there \
32 are differences else :keyword:`False`.
32 are differences else :keyword:`False`.
33 """
33 """
34 return SchemaDiff(metadataA, metadataB, excludeTables)
34 return SchemaDiff(metadataA, metadataB, excludeTables)
35
35
36
36
37 class ColDiff(object):
37 class ColDiff(object):
38 """
38 """
39 Container for differences in one :class:`~sqlalchemy.schema.Column`
39 Container for differences in one :class:`~sqlalchemy.schema.Column`
40 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
40 between two :class:`~sqlalchemy.schema.Table` instances, ``A``
41 and ``B``.
41 and ``B``.
42
42
43 .. attribute:: col_A
43 .. attribute:: col_A
44
44
45 The :class:`~sqlalchemy.schema.Column` object for A.
45 The :class:`~sqlalchemy.schema.Column` object for A.
46
46
47 .. attribute:: col_B
47 .. attribute:: col_B
48
48
49 The :class:`~sqlalchemy.schema.Column` object for B.
49 The :class:`~sqlalchemy.schema.Column` object for B.
50
50
51 .. attribute:: type_A
51 .. attribute:: type_A
52
52
53 The most generic type of the :class:`~sqlalchemy.schema.Column`
53 The most generic type of the :class:`~sqlalchemy.schema.Column`
54 object in A.
54 object in A.
55
55
56 .. attribute:: type_B
56 .. attribute:: type_B
57
57
58 The most generic type of the :class:`~sqlalchemy.schema.Column`
58 The most generic type of the :class:`~sqlalchemy.schema.Column`
59 object in A.
59 object in A.
60
60
61 """
61 """
62
62
63 diff = False
63 diff = False
64
64
65 def __init__(self,col_A,col_B):
65 def __init__(self,col_A,col_B):
66 self.col_A = col_A
66 self.col_A = col_A
67 self.col_B = col_B
67 self.col_B = col_B
68
68
69 self.type_A = col_A.type
69 self.type_A = col_A.type
70 self.type_B = col_B.type
70 self.type_B = col_B.type
71
71
72 self.affinity_A = self.type_A._type_affinity
72 self.affinity_A = self.type_A._type_affinity
73 self.affinity_B = self.type_B._type_affinity
73 self.affinity_B = self.type_B._type_affinity
74
74
75 if self.affinity_A is not self.affinity_B:
75 if self.affinity_A is not self.affinity_B:
76 self.diff = True
76 self.diff = True
77 return
77 return
78
78
79 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
79 if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
80 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
80 if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
81 self.diff=True
81 self.diff=True
82 return
82 return
83
83
84 for attr in ('precision','scale','length'):
84 for attr in ('precision','scale','length'):
85 A = getattr(self.type_A,attr,None)
85 A = getattr(self.type_A,attr,None)
86 B = getattr(self.type_B,attr,None)
86 B = getattr(self.type_B,attr,None)
87 if not (A is None or B is None) and A!=B:
87 if not (A is None or B is None) and A!=B:
88 self.diff=True
88 self.diff=True
89 return
89 return
90
90
91 def __nonzero__(self):
91 def __nonzero__(self):
92 return self.diff
92 return self.diff
93
93
94 class TableDiff(object):
94 class TableDiff(object):
95 """
95 """
96 Container for differences in one :class:`~sqlalchemy.schema.Table`
96 Container for differences in one :class:`~sqlalchemy.schema.Table`
97 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
97 between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
98 and ``B``.
98 and ``B``.
99
99
100 .. attribute:: columns_missing_from_A
100 .. attribute:: columns_missing_from_A
101
101
102 A sequence of column names that were found in B but weren't in
102 A sequence of column names that were found in B but weren't in
103 A.
103 A.
104
104
105 .. attribute:: columns_missing_from_B
105 .. attribute:: columns_missing_from_B
106
106
107 A sequence of column names that were found in A but weren't in
107 A sequence of column names that were found in A but weren't in
108 B.
108 B.
109
109
110 .. attribute:: columns_different
110 .. attribute:: columns_different
111
111
112 A dictionary containing information about columns that were
112 A dictionary containing information about columns that were
113 found to be different.
113 found to be different.
114 It maps column names to a :class:`ColDiff` objects describing the
114 It maps column names to a :class:`ColDiff` objects describing the
115 differences found.
115 differences found.
116 """
116 """
117 __slots__ = (
117 __slots__ = (
118 'columns_missing_from_A',
118 'columns_missing_from_A',
119 'columns_missing_from_B',
119 'columns_missing_from_B',
120 'columns_different',
120 'columns_different',
121 )
121 )
122
122
123 def __nonzero__(self):
123 def __nonzero__(self):
124 return bool(
124 return bool(
125 self.columns_missing_from_A or
125 self.columns_missing_from_A or
126 self.columns_missing_from_B or
126 self.columns_missing_from_B or
127 self.columns_different
127 self.columns_different
128 )
128 )
129
129
130 class SchemaDiff(object):
130 class SchemaDiff(object):
131 """
131 """
132 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
132 Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
133 objects.
133 objects.
134
134
135 The string representation of a :class:`SchemaDiff` will summarise
135 The string representation of a :class:`SchemaDiff` will summarise
136 the changes found between the two
136 the changes found between the two
137 :class:`~sqlalchemy.schema.MetaData` objects.
137 :class:`~sqlalchemy.schema.MetaData` objects.
138
138
139 The length of a :class:`SchemaDiff` will give the number of
139 The length of a :class:`SchemaDiff` will give the number of
140 changes found, enabling it to be used much like a boolean in
140 changes found, enabling it to be used much like a boolean in
141 expressions.
141 expressions.
142
142
143 :param metadataA:
143 :param metadataA:
144 First :class:`~sqlalchemy.schema.MetaData` to compare.
144 First :class:`~sqlalchemy.schema.MetaData` to compare.
145
145
146 :param metadataB:
146 :param metadataB:
147 Second :class:`~sqlalchemy.schema.MetaData` to compare.
147 Second :class:`~sqlalchemy.schema.MetaData` to compare.
148
148
149 :param labelA:
149 :param labelA:
150 The label to use in messages about the first
150 The label to use in messages about the first
151 :class:`~sqlalchemy.schema.MetaData`.
151 :class:`~sqlalchemy.schema.MetaData`.
152
152
153 :param labelB:
153 :param labelB:
154 The label to use in messages about the second
154 The label to use in messages about the second
155 :class:`~sqlalchemy.schema.MetaData`.
155 :class:`~sqlalchemy.schema.MetaData`.
156
156
157 :param excludeTables:
157 :param excludeTables:
158 A sequence of table names to exclude.
158 A sequence of table names to exclude.
159
159
160 .. attribute:: tables_missing_from_A
160 .. attribute:: tables_missing_from_A
161
161
162 A sequence of table names that were found in B but weren't in
162 A sequence of table names that were found in B but weren't in
163 A.
163 A.
164
164
165 .. attribute:: tables_missing_from_B
165 .. attribute:: tables_missing_from_B
166
166
167 A sequence of table names that were found in A but weren't in
167 A sequence of table names that were found in A but weren't in
168 B.
168 B.
169
169
170 .. attribute:: tables_different
170 .. attribute:: tables_different
171
171
172 A dictionary containing information about tables that were found
172 A dictionary containing information about tables that were found
173 to be different.
173 to be different.
174 It maps table names to a :class:`TableDiff` objects describing the
174 It maps table names to a :class:`TableDiff` objects describing the
175 differences found.
175 differences found.
176 """
176 """
177
177
178 def __init__(self,
178 def __init__(self,
179 metadataA, metadataB,
179 metadataA, metadataB,
180 labelA='metadataA',
180 labelA='metadataA',
181 labelB='metadataB',
181 labelB='metadataB',
182 excludeTables=None):
182 excludeTables=None):
183
183
184 self.metadataA, self.metadataB = metadataA, metadataB
184 self.metadataA, self.metadataB = metadataA, metadataB
185 self.labelA, self.labelB = labelA, labelB
185 self.labelA, self.labelB = labelA, labelB
186 self.label_width = max(len(labelA),len(labelB))
186 self.label_width = max(len(labelA),len(labelB))
187 excludeTables = set(excludeTables or [])
187 excludeTables = set(excludeTables or [])
188
188
189 A_table_names = set(metadataA.tables.keys())
189 A_table_names = set(metadataA.tables.keys())
190 B_table_names = set(metadataB.tables.keys())
190 B_table_names = set(metadataB.tables.keys())
191
191
192 self.tables_missing_from_A = sorted(
192 self.tables_missing_from_A = sorted(
193 B_table_names - A_table_names - excludeTables
193 B_table_names - A_table_names - excludeTables
194 )
194 )
195 self.tables_missing_from_B = sorted(
195 self.tables_missing_from_B = sorted(
196 A_table_names - B_table_names - excludeTables
196 A_table_names - B_table_names - excludeTables
197 )
197 )
198
198
199 self.tables_different = {}
199 self.tables_different = {}
200 for table_name in A_table_names.intersection(B_table_names):
200 for table_name in A_table_names.intersection(B_table_names):
201
201
202 td = TableDiff()
202 td = TableDiff()
203
203
204 A_table = metadataA.tables[table_name]
204 A_table = metadataA.tables[table_name]
205 B_table = metadataB.tables[table_name]
205 B_table = metadataB.tables[table_name]
206
206
207 A_column_names = set(A_table.columns.keys())
207 A_column_names = set(A_table.columns.keys())
208 B_column_names = set(B_table.columns.keys())
208 B_column_names = set(B_table.columns.keys())
209
209
210 td.columns_missing_from_A = sorted(
210 td.columns_missing_from_A = sorted(
211 B_column_names - A_column_names
211 B_column_names - A_column_names
212 )
212 )
213
213
214 td.columns_missing_from_B = sorted(
214 td.columns_missing_from_B = sorted(
215 A_column_names - B_column_names
215 A_column_names - B_column_names
216 )
216 )
217
217
218 td.columns_different = {}
218 td.columns_different = {}
219
219
220 for col_name in A_column_names.intersection(B_column_names):
220 for col_name in A_column_names.intersection(B_column_names):
221
221
222 cd = ColDiff(
222 cd = ColDiff(
223 A_table.columns.get(col_name),
223 A_table.columns.get(col_name),
224 B_table.columns.get(col_name)
224 B_table.columns.get(col_name)
225 )
225 )
226
226
227 if cd:
227 if cd:
228 td.columns_different[col_name]=cd
228 td.columns_different[col_name]=cd
229
229
230 # XXX - index and constraint differences should
230 # XXX - index and constraint differences should
231 # be checked for here
231 # be checked for here
232
232
233 if td:
233 if td:
234 self.tables_different[table_name]=td
234 self.tables_different[table_name]=td
235
235
236 def __str__(self):
236 def __str__(self):
237 ''' Summarize differences. '''
237 ''' Summarize differences. '''
238 out = []
238 out = []
239 column_template =' %%%is: %%r' % self.label_width
239 column_template =' %%%is: %%r' % self.label_width
240
240
241 for names,label in (
241 for names,label in (
242 (self.tables_missing_from_A,self.labelA),
242 (self.tables_missing_from_A,self.labelA),
243 (self.tables_missing_from_B,self.labelB),
243 (self.tables_missing_from_B,self.labelB),
244 ):
244 ):
245 if names:
245 if names:
246 out.append(
246 out.append(
247 ' tables missing from %s: %s' % (
247 ' tables missing from %s: %s' % (
248 label,', '.join(sorted(names))
248 label,', '.join(sorted(names))
249 )
249 )
250 )
250 )
251
251
252 for name,td in sorted(self.tables_different.items()):
252 for name,td in sorted(self.tables_different.items()):
253 out.append(
253 out.append(
254 ' table with differences: %s' % name
254 ' table with differences: %s' % name
255 )
255 )
256 for names,label in (
256 for names,label in (
257 (td.columns_missing_from_A,self.labelA),
257 (td.columns_missing_from_A,self.labelA),
258 (td.columns_missing_from_B,self.labelB),
258 (td.columns_missing_from_B,self.labelB),
259 ):
259 ):
260 if names:
260 if names:
261 out.append(
261 out.append(
262 ' %s missing these columns: %s' % (
262 ' %s missing these columns: %s' % (
263 label,', '.join(sorted(names))
263 label,', '.join(sorted(names))
264 )
264 )
265 )
265 )
266 for name,cd in td.columns_different.items():
266 for name,cd in td.columns_different.items():
267 out.append(' column with differences: %s' % name)
267 out.append(' column with differences: %s' % name)
268 out.append(column_template % (self.labelA,cd.col_A))
268 out.append(column_template % (self.labelA,cd.col_A))
269 out.append(column_template % (self.labelB,cd.col_B))
269 out.append(column_template % (self.labelB,cd.col_B))
270
270
271 if out:
271 if out:
272 out.insert(0, 'Schema diffs:')
272 out.insert(0, 'Schema diffs:')
273 return '\n'.join(out)
273 return '\n'.join(out)
274 else:
274 else:
275 return 'No schema diffs'
275 return 'No schema diffs'
276
276
277 def __len__(self):
277 def __len__(self):
278 """
278 """
279 Used in bool evaluation, return of 0 means no diffs.
279 Used in bool evaluation, return of 0 means no diffs.
280 """
280 """
281 return (
281 return (
282 len(self.tables_missing_from_A) +
282 len(self.tables_missing_from_A) +
283 len(self.tables_missing_from_B) +
283 len(self.tables_missing_from_B) +
284 len(self.tables_different)
284 len(self.tables_different)
285 )
285 )
@@ -1,6 +1,6 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 from migrate.versioning.script.base import BaseScript
4 from rhodecode.lib.dbmigrate.migrate.versioning.script.base import BaseScript
5 from migrate.versioning.script.py import PythonScript
5 from rhodecode.lib.dbmigrate.migrate.versioning.script.py import PythonScript
6 from migrate.versioning.script.sql import SqlScript
6 from rhodecode.lib.dbmigrate.migrate.versioning.script.sql import SqlScript
@@ -1,57 +1,57 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 import logging
3 import logging
4
4
5 from migrate import exceptions
5 from rhodecode.lib.dbmigrate.migrate import exceptions
6 from migrate.versioning.config import operations
6 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
7 from migrate.versioning import pathed
7 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
8
8
9
9
10 log = logging.getLogger(__name__)
10 log = logging.getLogger(__name__)
11
11
12 class BaseScript(pathed.Pathed):
12 class BaseScript(pathed.Pathed):
13 """Base class for other types of scripts.
13 """Base class for other types of scripts.
14 All scripts have the following properties:
14 All scripts have the following properties:
15
15
16 source (script.source())
16 source (script.source())
17 The source code of the script
17 The source code of the script
18 version (script.version())
18 version (script.version())
19 The version number of the script
19 The version number of the script
20 operations (script.operations())
20 operations (script.operations())
21 The operations defined by the script: upgrade(), downgrade() or both.
21 The operations defined by the script: upgrade(), downgrade() or both.
22 Returns a tuple of operations.
22 Returns a tuple of operations.
23 Can also check for an operation with ex. script.operation(Script.ops.up)
23 Can also check for an operation with ex. script.operation(Script.ops.up)
24 """ # TODO: sphinxfy this and implement it correctly
24 """ # TODO: sphinxfy this and implement it correctly
25
25
26 def __init__(self, path):
26 def __init__(self, path):
27 log.debug('Loading script %s...' % path)
27 log.debug('Loading script %s...' % path)
28 self.verify(path)
28 self.verify(path)
29 super(BaseScript, self).__init__(path)
29 super(BaseScript, self).__init__(path)
30 log.debug('Script %s loaded successfully' % path)
30 log.debug('Script %s loaded successfully' % path)
31
31
32 @classmethod
32 @classmethod
33 def verify(cls, path):
33 def verify(cls, path):
34 """Ensure this is a valid script
34 """Ensure this is a valid script
35 This version simply ensures the script file's existence
35 This version simply ensures the script file's existence
36
36
37 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
37 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
38 """
38 """
39 try:
39 try:
40 cls.require_found(path)
40 cls.require_found(path)
41 except:
41 except:
42 raise exceptions.InvalidScriptError(path)
42 raise exceptions.InvalidScriptError(path)
43
43
44 def source(self):
44 def source(self):
45 """:returns: source code of the script.
45 """:returns: source code of the script.
46 :rtype: string
46 :rtype: string
47 """
47 """
48 fd = open(self.path)
48 fd = open(self.path)
49 ret = fd.read()
49 ret = fd.read()
50 fd.close()
50 fd.close()
51 return ret
51 return ret
52
52
53 def run(self, engine):
53 def run(self, engine):
54 """Core of each BaseScript subclass.
54 """Core of each BaseScript subclass.
55 This method executes the script.
55 This method executes the script.
56 """
56 """
57 raise NotImplementedError()
57 raise NotImplementedError()
@@ -1,159 +1,159 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 import shutil
4 import shutil
5 import warnings
5 import warnings
6 import logging
6 import logging
7 from StringIO import StringIO
7 from StringIO import StringIO
8
8
9 import migrate
9 from rhodecode.lib.dbmigrate import migrate
10 from migrate.versioning import genmodel, schemadiff
10 from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff
11 from migrate.versioning.config import operations
11 from rhodecode.lib.dbmigrate.migrate.versioning.config import operations
12 from migrate.versioning.template import Template
12 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
13 from migrate.versioning.script import base
13 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
14 from migrate.versioning.util import import_path, load_model, with_engine
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine
15 from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
15 from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
16
16
17 log = logging.getLogger(__name__)
17 log = logging.getLogger(__name__)
18 __all__ = ['PythonScript']
18 __all__ = ['PythonScript']
19
19
20
20
21 class PythonScript(base.BaseScript):
21 class PythonScript(base.BaseScript):
22 """Base for Python scripts"""
22 """Base for Python scripts"""
23
23
24 @classmethod
24 @classmethod
25 def create(cls, path, **opts):
25 def create(cls, path, **opts):
26 """Create an empty migration script at specified path
26 """Create an empty migration script at specified path
27
27
28 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
28 :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
29 cls.require_notfound(path)
29 cls.require_notfound(path)
30
30
31 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
31 src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
32 shutil.copy(src, path)
32 shutil.copy(src, path)
33
33
34 return cls(path)
34 return cls(path)
35
35
36 @classmethod
36 @classmethod
37 def make_update_script_for_model(cls, engine, oldmodel,
37 def make_update_script_for_model(cls, engine, oldmodel,
38 model, repository, **opts):
38 model, repository, **opts):
39 """Create a migration script based on difference between two SA models.
39 """Create a migration script based on difference between two SA models.
40
40
41 :param repository: path to migrate repository
41 :param repository: path to migrate repository
42 :param oldmodel: dotted.module.name:SAClass or SAClass object
42 :param oldmodel: dotted.module.name:SAClass or SAClass object
43 :param model: dotted.module.name:SAClass or SAClass object
43 :param model: dotted.module.name:SAClass or SAClass object
44 :param engine: SQLAlchemy engine
44 :param engine: SQLAlchemy engine
45 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
45 :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
46 :type oldmodel: string or Class
46 :type oldmodel: string or Class
47 :type model: string or Class
47 :type model: string or Class
48 :type engine: Engine instance
48 :type engine: Engine instance
49 :returns: Upgrade / Downgrade script
49 :returns: Upgrade / Downgrade script
50 :rtype: string
50 :rtype: string
51 """
51 """
52
52
53 if isinstance(repository, basestring):
53 if isinstance(repository, basestring):
54 # oh dear, an import cycle!
54 # oh dear, an import cycle!
55 from migrate.versioning.repository import Repository
55 from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository
56 repository = Repository(repository)
56 repository = Repository(repository)
57
57
58 oldmodel = load_model(oldmodel)
58 oldmodel = load_model(oldmodel)
59 model = load_model(model)
59 model = load_model(model)
60
60
61 # Compute differences.
61 # Compute differences.
62 diff = schemadiff.getDiffOfModelAgainstModel(
62 diff = schemadiff.getDiffOfModelAgainstModel(
63 oldmodel,
63 oldmodel,
64 model,
64 model,
65 excludeTables=[repository.version_table])
65 excludeTables=[repository.version_table])
66 # TODO: diff can be False (there is no difference?)
66 # TODO: diff can be False (there is no difference?)
67 decls, upgradeCommands, downgradeCommands = \
67 decls, upgradeCommands, downgradeCommands = \
68 genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython()
68 genmodel.ModelGenerator(diff, engine).toUpgradeDowngradePython()
69
69
70 # Store differences into file.
70 # Store differences into file.
71 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
71 src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
72 f = open(src)
72 f = open(src)
73 contents = f.read()
73 contents = f.read()
74 f.close()
74 f.close()
75
75
76 # generate source
76 # generate source
77 search = 'def upgrade(migrate_engine):'
77 search = 'def upgrade(migrate_engine):'
78 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
78 contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
79 if upgradeCommands:
79 if upgradeCommands:
80 contents = contents.replace(' pass', upgradeCommands, 1)
80 contents = contents.replace(' pass', upgradeCommands, 1)
81 if downgradeCommands:
81 if downgradeCommands:
82 contents = contents.replace(' pass', downgradeCommands, 1)
82 contents = contents.replace(' pass', downgradeCommands, 1)
83 return contents
83 return contents
84
84
85 @classmethod
85 @classmethod
86 def verify_module(cls, path):
86 def verify_module(cls, path):
87 """Ensure path is a valid script
87 """Ensure path is a valid script
88
88
89 :param path: Script location
89 :param path: Script location
90 :type path: string
90 :type path: string
91 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
91 :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
92 :returns: Python module
92 :returns: Python module
93 """
93 """
94 # Try to import and get the upgrade() func
94 # Try to import and get the upgrade() func
95 module = import_path(path)
95 module = import_path(path)
96 try:
96 try:
97 assert callable(module.upgrade)
97 assert callable(module.upgrade)
98 except Exception, e:
98 except Exception, e:
99 raise InvalidScriptError(path + ': %s' % str(e))
99 raise InvalidScriptError(path + ': %s' % str(e))
100 return module
100 return module
101
101
102 def preview_sql(self, url, step, **args):
102 def preview_sql(self, url, step, **args):
103 """Mocks SQLAlchemy Engine to store all executed calls in a string
103 """Mocks SQLAlchemy Engine to store all executed calls in a string
104 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
104 and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
105
105
106 :returns: SQL file
106 :returns: SQL file
107 """
107 """
108 buf = StringIO()
108 buf = StringIO()
109 args['engine_arg_strategy'] = 'mock'
109 args['engine_arg_strategy'] = 'mock'
110 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
110 args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
111
111
112 @with_engine
112 @with_engine
113 def go(url, step, **kw):
113 def go(url, step, **kw):
114 engine = kw.pop('engine')
114 engine = kw.pop('engine')
115 self.run(engine, step)
115 self.run(engine, step)
116 return buf.getvalue()
116 return buf.getvalue()
117
117
118 return go(url, step, **args)
118 return go(url, step, **args)
119
119
120 def run(self, engine, step):
120 def run(self, engine, step):
121 """Core method of Script file.
121 """Core method of Script file.
122 Exectues :func:`update` or :func:`downgrade` functions
122 Exectues :func:`update` or :func:`downgrade` functions
123
123
124 :param engine: SQLAlchemy Engine
124 :param engine: SQLAlchemy Engine
125 :param step: Operation to run
125 :param step: Operation to run
126 :type engine: string
126 :type engine: string
127 :type step: int
127 :type step: int
128 """
128 """
129 if step > 0:
129 if step > 0:
130 op = 'upgrade'
130 op = 'upgrade'
131 elif step < 0:
131 elif step < 0:
132 op = 'downgrade'
132 op = 'downgrade'
133 else:
133 else:
134 raise ScriptError("%d is not a valid step" % step)
134 raise ScriptError("%d is not a valid step" % step)
135
135
136 funcname = base.operations[op]
136 funcname = base.operations[op]
137 script_func = self._func(funcname)
137 script_func = self._func(funcname)
138
138
139 try:
139 try:
140 script_func(engine)
140 script_func(engine)
141 except TypeError:
141 except TypeError:
142 warnings.warn("upgrade/downgrade functions must accept engine"
142 warnings.warn("upgrade/downgrade functions must accept engine"
143 " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
143 " parameter (since version > 0.5.4)", MigrateDeprecationWarning)
144 raise
144 raise
145
145
146 @property
146 @property
147 def module(self):
147 def module(self):
148 """Calls :meth:`migrate.versioning.script.py.verify_module`
148 """Calls :meth:`migrate.versioning.script.py.verify_module`
149 and returns it.
149 and returns it.
150 """
150 """
151 if not hasattr(self, '_module'):
151 if not hasattr(self, '_module'):
152 self._module = self.verify_module(self.path)
152 self._module = self.verify_module(self.path)
153 return self._module
153 return self._module
154
154
155 def _func(self, funcname):
155 def _func(self, funcname):
156 if not hasattr(self.module, funcname):
156 if not hasattr(self.module, funcname):
157 msg = "Function '%s' is not defined in this script"
157 msg = "Function '%s' is not defined in this script"
158 raise ScriptError(msg % funcname)
158 raise ScriptError(msg % funcname)
159 return getattr(self.module, funcname)
159 return getattr(self.module, funcname)
@@ -1,49 +1,49 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 import logging
3 import logging
4 import shutil
4 import shutil
5
5
6 from migrate.versioning.script import base
6 from rhodecode.lib.dbmigrate.migrate.versioning.script import base
7 from migrate.versioning.template import Template
7 from rhodecode.lib.dbmigrate.migrate.versioning.template import Template
8
8
9
9
10 log = logging.getLogger(__name__)
10 log = logging.getLogger(__name__)
11
11
12 class SqlScript(base.BaseScript):
12 class SqlScript(base.BaseScript):
13 """A file containing plain SQL statements."""
13 """A file containing plain SQL statements."""
14
14
15 @classmethod
15 @classmethod
16 def create(cls, path, **opts):
16 def create(cls, path, **opts):
17 """Create an empty migration script at specified path
17 """Create an empty migration script at specified path
18
18
19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
19 :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
20 cls.require_notfound(path)
20 cls.require_notfound(path)
21
21
22 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
22 src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
23 shutil.copy(src, path)
23 shutil.copy(src, path)
24 return cls(path)
24 return cls(path)
25
25
26 # TODO: why is step parameter even here?
26 # TODO: why is step parameter even here?
27 def run(self, engine, step=None, executemany=True):
27 def run(self, engine, step=None, executemany=True):
28 """Runs SQL script through raw dbapi execute call"""
28 """Runs SQL script through raw dbapi execute call"""
29 text = self.source()
29 text = self.source()
30 # Don't rely on SA's autocommit here
30 # Don't rely on SA's autocommit here
31 # (SA uses .startswith to check if a commit is needed. What if script
31 # (SA uses .startswith to check if a commit is needed. What if script
32 # starts with a comment?)
32 # starts with a comment?)
33 conn = engine.connect()
33 conn = engine.connect()
34 try:
34 try:
35 trans = conn.begin()
35 trans = conn.begin()
36 try:
36 try:
37 # HACK: SQLite doesn't allow multiple statements through
37 # HACK: SQLite doesn't allow multiple statements through
38 # its execute() method, but it provides executescript() instead
38 # its execute() method, but it provides executescript() instead
39 dbapi = conn.engine.raw_connection()
39 dbapi = conn.engine.raw_connection()
40 if executemany and getattr(dbapi, 'executescript', None):
40 if executemany and getattr(dbapi, 'executescript', None):
41 dbapi.executescript(text)
41 dbapi.executescript(text)
42 else:
42 else:
43 conn.execute(text)
43 conn.execute(text)
44 trans.commit()
44 trans.commit()
45 except:
45 except:
46 trans.rollback()
46 trans.rollback()
47 raise
47 raise
48 finally:
48 finally:
49 conn.close()
49 conn.close()
@@ -1,215 +1,215 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 """The migrate command-line tool."""
4 """The migrate command-line tool."""
5
5
6 import sys
6 import sys
7 import inspect
7 import inspect
8 import logging
8 import logging
9 from optparse import OptionParser, BadOptionError
9 from optparse import OptionParser, BadOptionError
10
10
11 from migrate import exceptions
11 from rhodecode.lib.dbmigrate.migrate import exceptions
12 from migrate.versioning import api
12 from rhodecode.lib.dbmigrate.migrate.versioning import api
13 from migrate.versioning.config import *
13 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
14 from migrate.versioning.util import asbool
14 from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool
15
15
16
16
17 alias = dict(
17 alias = dict(
18 s=api.script,
18 s=api.script,
19 vc=api.version_control,
19 vc=api.version_control,
20 dbv=api.db_version,
20 dbv=api.db_version,
21 v=api.version,
21 v=api.version,
22 )
22 )
23
23
24 def alias_setup():
24 def alias_setup():
25 global alias
25 global alias
26 for key, val in alias.iteritems():
26 for key, val in alias.iteritems():
27 setattr(api, key, val)
27 setattr(api, key, val)
28 alias_setup()
28 alias_setup()
29
29
30
30
31 class PassiveOptionParser(OptionParser):
31 class PassiveOptionParser(OptionParser):
32
32
33 def _process_args(self, largs, rargs, values):
33 def _process_args(self, largs, rargs, values):
34 """little hack to support all --some_option=value parameters"""
34 """little hack to support all --some_option=value parameters"""
35
35
36 while rargs:
36 while rargs:
37 arg = rargs[0]
37 arg = rargs[0]
38 if arg == "--":
38 if arg == "--":
39 del rargs[0]
39 del rargs[0]
40 return
40 return
41 elif arg[0:2] == "--":
41 elif arg[0:2] == "--":
42 # if parser does not know about the option
42 # if parser does not know about the option
43 # pass it along (make it anonymous)
43 # pass it along (make it anonymous)
44 try:
44 try:
45 opt = arg.split('=', 1)[0]
45 opt = arg.split('=', 1)[0]
46 self._match_long_opt(opt)
46 self._match_long_opt(opt)
47 except BadOptionError:
47 except BadOptionError:
48 largs.append(arg)
48 largs.append(arg)
49 del rargs[0]
49 del rargs[0]
50 else:
50 else:
51 self._process_long_opt(rargs, values)
51 self._process_long_opt(rargs, values)
52 elif arg[:1] == "-" and len(arg) > 1:
52 elif arg[:1] == "-" and len(arg) > 1:
53 self._process_short_opts(rargs, values)
53 self._process_short_opts(rargs, values)
54 elif self.allow_interspersed_args:
54 elif self.allow_interspersed_args:
55 largs.append(arg)
55 largs.append(arg)
56 del rargs[0]
56 del rargs[0]
57
57
58 def main(argv=None, **kwargs):
58 def main(argv=None, **kwargs):
59 """Shell interface to :mod:`migrate.versioning.api`.
59 """Shell interface to :mod:`migrate.versioning.api`.
60
60
61 kwargs are default options that can be overriden with passing
61 kwargs are default options that can be overriden with passing
62 --some_option as command line option
62 --some_option as command line option
63
63
64 :param disable_logging: Let migrate configure logging
64 :param disable_logging: Let migrate configure logging
65 :type disable_logging: bool
65 :type disable_logging: bool
66 """
66 """
67 if argv is not None:
67 if argv is not None:
68 argv = argv
68 argv = argv
69 else:
69 else:
70 argv = list(sys.argv[1:])
70 argv = list(sys.argv[1:])
71 commands = list(api.__all__)
71 commands = list(api.__all__)
72 commands.sort()
72 commands.sort()
73
73
74 usage = """%%prog COMMAND ...
74 usage = """%%prog COMMAND ...
75
75
76 Available commands:
76 Available commands:
77 %s
77 %s
78
78
79 Enter "%%prog help COMMAND" for information on a particular command.
79 Enter "%%prog help COMMAND" for information on a particular command.
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
80 """ % '\n\t'.join(["%s - %s" % (command.ljust(28),
81 api.command_desc.get(command)) for command in commands])
81 api.command_desc.get(command)) for command in commands])
82
82
83 parser = PassiveOptionParser(usage=usage)
83 parser = PassiveOptionParser(usage=usage)
84 parser.add_option("-d", "--debug",
84 parser.add_option("-d", "--debug",
85 action="store_true",
85 action="store_true",
86 dest="debug",
86 dest="debug",
87 default=False,
87 default=False,
88 help="Shortcut to turn on DEBUG mode for logging")
88 help="Shortcut to turn on DEBUG mode for logging")
89 parser.add_option("-q", "--disable_logging",
89 parser.add_option("-q", "--disable_logging",
90 action="store_true",
90 action="store_true",
91 dest="disable_logging",
91 dest="disable_logging",
92 default=False,
92 default=False,
93 help="Use this option to disable logging configuration")
93 help="Use this option to disable logging configuration")
94 help_commands = ['help', '-h', '--help']
94 help_commands = ['help', '-h', '--help']
95 HELP = False
95 HELP = False
96
96
97 try:
97 try:
98 command = argv.pop(0)
98 command = argv.pop(0)
99 if command in help_commands:
99 if command in help_commands:
100 HELP = True
100 HELP = True
101 command = argv.pop(0)
101 command = argv.pop(0)
102 except IndexError:
102 except IndexError:
103 parser.print_help()
103 parser.print_help()
104 return
104 return
105
105
106 command_func = getattr(api, command, None)
106 command_func = getattr(api, command, None)
107 if command_func is None or command.startswith('_'):
107 if command_func is None or command.startswith('_'):
108 parser.error("Invalid command %s" % command)
108 parser.error("Invalid command %s" % command)
109
109
110 parser.set_usage(inspect.getdoc(command_func))
110 parser.set_usage(inspect.getdoc(command_func))
111 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
111 f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
112 for arg in f_args:
112 for arg in f_args:
113 parser.add_option(
113 parser.add_option(
114 "--%s" % arg,
114 "--%s" % arg,
115 dest=arg,
115 dest=arg,
116 action='store',
116 action='store',
117 type="string")
117 type="string")
118
118
119 # display help of the current command
119 # display help of the current command
120 if HELP:
120 if HELP:
121 parser.print_help()
121 parser.print_help()
122 return
122 return
123
123
124 options, args = parser.parse_args(argv)
124 options, args = parser.parse_args(argv)
125
125
126 # override kwargs with anonymous parameters
126 # override kwargs with anonymous parameters
127 override_kwargs = dict()
127 override_kwargs = dict()
128 for arg in list(args):
128 for arg in list(args):
129 if arg.startswith('--'):
129 if arg.startswith('--'):
130 args.remove(arg)
130 args.remove(arg)
131 if '=' in arg:
131 if '=' in arg:
132 opt, value = arg[2:].split('=', 1)
132 opt, value = arg[2:].split('=', 1)
133 else:
133 else:
134 opt = arg[2:]
134 opt = arg[2:]
135 value = True
135 value = True
136 override_kwargs[opt] = value
136 override_kwargs[opt] = value
137
137
138 # override kwargs with options if user is overwriting
138 # override kwargs with options if user is overwriting
139 for key, value in options.__dict__.iteritems():
139 for key, value in options.__dict__.iteritems():
140 if value is not None:
140 if value is not None:
141 override_kwargs[key] = value
141 override_kwargs[key] = value
142
142
143 # arguments that function accepts without passed kwargs
143 # arguments that function accepts without passed kwargs
144 f_required = list(f_args)
144 f_required = list(f_args)
145 candidates = dict(kwargs)
145 candidates = dict(kwargs)
146 candidates.update(override_kwargs)
146 candidates.update(override_kwargs)
147 for key, value in candidates.iteritems():
147 for key, value in candidates.iteritems():
148 if key in f_args:
148 if key in f_args:
149 f_required.remove(key)
149 f_required.remove(key)
150
150
151 # map function arguments to parsed arguments
151 # map function arguments to parsed arguments
152 for arg in args:
152 for arg in args:
153 try:
153 try:
154 kw = f_required.pop(0)
154 kw = f_required.pop(0)
155 except IndexError:
155 except IndexError:
156 parser.error("Too many arguments for command %s: %s" % (command,
156 parser.error("Too many arguments for command %s: %s" % (command,
157 arg))
157 arg))
158 kwargs[kw] = arg
158 kwargs[kw] = arg
159
159
160 # apply overrides
160 # apply overrides
161 kwargs.update(override_kwargs)
161 kwargs.update(override_kwargs)
162
162
163 # configure options
163 # configure options
164 for key, value in options.__dict__.iteritems():
164 for key, value in options.__dict__.iteritems():
165 kwargs.setdefault(key, value)
165 kwargs.setdefault(key, value)
166
166
167 # configure logging
167 # configure logging
168 if not asbool(kwargs.pop('disable_logging', False)):
168 if not asbool(kwargs.pop('disable_logging', False)):
169 # filter to log =< INFO into stdout and rest to stderr
169 # filter to log =< INFO into stdout and rest to stderr
170 class SingleLevelFilter(logging.Filter):
170 class SingleLevelFilter(logging.Filter):
171 def __init__(self, min=None, max=None):
171 def __init__(self, min=None, max=None):
172 self.min = min or 0
172 self.min = min or 0
173 self.max = max or 100
173 self.max = max or 100
174
174
175 def filter(self, record):
175 def filter(self, record):
176 return self.min <= record.levelno <= self.max
176 return self.min <= record.levelno <= self.max
177
177
178 logger = logging.getLogger()
178 logger = logging.getLogger()
179 h1 = logging.StreamHandler(sys.stdout)
179 h1 = logging.StreamHandler(sys.stdout)
180 f1 = SingleLevelFilter(max=logging.INFO)
180 f1 = SingleLevelFilter(max=logging.INFO)
181 h1.addFilter(f1)
181 h1.addFilter(f1)
182 h2 = logging.StreamHandler(sys.stderr)
182 h2 = logging.StreamHandler(sys.stderr)
183 f2 = SingleLevelFilter(min=logging.WARN)
183 f2 = SingleLevelFilter(min=logging.WARN)
184 h2.addFilter(f2)
184 h2.addFilter(f2)
185 logger.addHandler(h1)
185 logger.addHandler(h1)
186 logger.addHandler(h2)
186 logger.addHandler(h2)
187
187
188 if options.debug:
188 if options.debug:
189 logger.setLevel(logging.DEBUG)
189 logger.setLevel(logging.DEBUG)
190 else:
190 else:
191 logger.setLevel(logging.INFO)
191 logger.setLevel(logging.INFO)
192
192
193 log = logging.getLogger(__name__)
193 log = logging.getLogger(__name__)
194
194
195 # check if all args are given
195 # check if all args are given
196 try:
196 try:
197 num_defaults = len(f_defaults)
197 num_defaults = len(f_defaults)
198 except TypeError:
198 except TypeError:
199 num_defaults = 0
199 num_defaults = 0
200 f_args_default = f_args[len(f_args) - num_defaults:]
200 f_args_default = f_args[len(f_args) - num_defaults:]
201 required = list(set(f_required) - set(f_args_default))
201 required = list(set(f_required) - set(f_args_default))
202 if required:
202 if required:
203 parser.error("Not enough arguments for command %s: %s not specified" \
203 parser.error("Not enough arguments for command %s: %s not specified" \
204 % (command, ', '.join(required)))
204 % (command, ', '.join(required)))
205
205
206 # handle command
206 # handle command
207 try:
207 try:
208 ret = command_func(**kwargs)
208 ret = command_func(**kwargs)
209 if ret is not None:
209 if ret is not None:
210 log.info(ret)
210 log.info(ret)
211 except (exceptions.UsageError, exceptions.KnownError), e:
211 except (exceptions.UsageError, exceptions.KnownError), e:
212 parser.error(e.args[0])
212 parser.error(e.args[0])
213
213
214 if __name__ == "__main__":
214 if __name__ == "__main__":
215 main()
215 main()
@@ -1,94 +1,94 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 import os
4 import os
5 import shutil
5 import shutil
6 import sys
6 import sys
7
7
8 from pkg_resources import resource_filename
8 from pkg_resources import resource_filename
9
9
10 from migrate.versioning.config import *
10 from rhodecode.lib.dbmigrate.migrate.versioning.config import *
11 from migrate.versioning import pathed
11 from rhodecode.lib.dbmigrate.migrate.versioning import pathed
12
12
13
13
14 class Collection(pathed.Pathed):
14 class Collection(pathed.Pathed):
15 """A collection of templates of a specific type"""
15 """A collection of templates of a specific type"""
16 _mask = None
16 _mask = None
17
17
18 def get_path(self, file):
18 def get_path(self, file):
19 return os.path.join(self.path, str(file))
19 return os.path.join(self.path, str(file))
20
20
21
21
22 class RepositoryCollection(Collection):
22 class RepositoryCollection(Collection):
23 _mask = '%s'
23 _mask = '%s'
24
24
25 class ScriptCollection(Collection):
25 class ScriptCollection(Collection):
26 _mask = '%s.py_tmpl'
26 _mask = '%s.py_tmpl'
27
27
28 class ManageCollection(Collection):
28 class ManageCollection(Collection):
29 _mask = '%s.py_tmpl'
29 _mask = '%s.py_tmpl'
30
30
31 class SQLScriptCollection(Collection):
31 class SQLScriptCollection(Collection):
32 _mask = '%s.py_tmpl'
32 _mask = '%s.py_tmpl'
33
33
34 class Template(pathed.Pathed):
34 class Template(pathed.Pathed):
35 """Finds the paths/packages of various Migrate templates.
35 """Finds the paths/packages of various Migrate templates.
36
36
37 :param path: Templates are loaded from migrate package
37 :param path: Templates are loaded from rhodecode.lib.dbmigrate.migrate package
38 if `path` is not provided.
38 if `path` is not provided.
39 """
39 """
40 pkg = 'migrate.versioning.templates'
40 pkg = 'migrate.versioning.templates'
41 _manage = 'manage.py_tmpl'
41 _manage = 'manage.py_tmpl'
42
42
43 def __new__(cls, path=None):
43 def __new__(cls, path=None):
44 if path is None:
44 if path is None:
45 path = cls._find_path(cls.pkg)
45 path = cls._find_path(cls.pkg)
46 return super(Template, cls).__new__(cls, path)
46 return super(Template, cls).__new__(cls, path)
47
47
48 def __init__(self, path=None):
48 def __init__(self, path=None):
49 if path is None:
49 if path is None:
50 path = Template._find_path(self.pkg)
50 path = Template._find_path(self.pkg)
51 super(Template, self).__init__(path)
51 super(Template, self).__init__(path)
52 self.repository = RepositoryCollection(os.path.join(path, 'repository'))
52 self.repository = RepositoryCollection(os.path.join(path, 'repository'))
53 self.script = ScriptCollection(os.path.join(path, 'script'))
53 self.script = ScriptCollection(os.path.join(path, 'script'))
54 self.manage = ManageCollection(os.path.join(path, 'manage'))
54 self.manage = ManageCollection(os.path.join(path, 'manage'))
55 self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
55 self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
56
56
57 @classmethod
57 @classmethod
58 def _find_path(cls, pkg):
58 def _find_path(cls, pkg):
59 """Returns absolute path to dotted python package."""
59 """Returns absolute path to dotted python package."""
60 tmp_pkg = pkg.rsplit('.', 1)
60 tmp_pkg = pkg.rsplit('.', 1)
61
61
62 if len(tmp_pkg) != 1:
62 if len(tmp_pkg) != 1:
63 return resource_filename(tmp_pkg[0], tmp_pkg[1])
63 return resource_filename(tmp_pkg[0], tmp_pkg[1])
64 else:
64 else:
65 return resource_filename(tmp_pkg[0], '')
65 return resource_filename(tmp_pkg[0], '')
66
66
67 def _get_item(self, collection, theme=None):
67 def _get_item(self, collection, theme=None):
68 """Locates and returns collection.
68 """Locates and returns collection.
69
69
70 :param collection: name of collection to locate
70 :param collection: name of collection to locate
71 :param type_: type of subfolder in collection (defaults to "_default")
71 :param type_: type of subfolder in collection (defaults to "_default")
72 :returns: (package, source)
72 :returns: (package, source)
73 :rtype: str, str
73 :rtype: str, str
74 """
74 """
75 item = getattr(self, collection)
75 item = getattr(self, collection)
76 theme_mask = getattr(item, '_mask')
76 theme_mask = getattr(item, '_mask')
77 theme = theme_mask % (theme or 'default')
77 theme = theme_mask % (theme or 'default')
78 return item.get_path(theme)
78 return item.get_path(theme)
79
79
80 def get_repository(self, *a, **kw):
80 def get_repository(self, *a, **kw):
81 """Calls self._get_item('repository', *a, **kw)"""
81 """Calls self._get_item('repository', *a, **kw)"""
82 return self._get_item('repository', *a, **kw)
82 return self._get_item('repository', *a, **kw)
83
83
84 def get_script(self, *a, **kw):
84 def get_script(self, *a, **kw):
85 """Calls self._get_item('script', *a, **kw)"""
85 """Calls self._get_item('script', *a, **kw)"""
86 return self._get_item('script', *a, **kw)
86 return self._get_item('script', *a, **kw)
87
87
88 def get_sql_script(self, *a, **kw):
88 def get_sql_script(self, *a, **kw):
89 """Calls self._get_item('sql_script', *a, **kw)"""
89 """Calls self._get_item('sql_script', *a, **kw)"""
90 return self._get_item('sql_script', *a, **kw)
90 return self._get_item('sql_script', *a, **kw)
91
91
92 def get_manage(self, *a, **kw):
92 def get_manage(self, *a, **kw):
93 """Calls self._get_item('manage', *a, **kw)"""
93 """Calls self._get_item('manage', *a, **kw)"""
94 return self._get_item('manage', *a, **kw)
94 return self._get_item('manage', *a, **kw)
@@ -1,179 +1,179 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 """.. currentmodule:: migrate.versioning.util"""
3 """.. currentmodule:: migrate.versioning.util"""
4
4
5 import warnings
5 import warnings
6 import logging
6 import logging
7 from decorator import decorator
7 from decorator import decorator
8 from pkg_resources import EntryPoint
8 from pkg_resources import EntryPoint
9
9
10 from sqlalchemy import create_engine
10 from sqlalchemy import create_engine
11 from sqlalchemy.engine import Engine
11 from sqlalchemy.engine import Engine
12 from sqlalchemy.pool import StaticPool
12 from sqlalchemy.pool import StaticPool
13
13
14 from migrate import exceptions
14 from rhodecode.lib.dbmigrate.migrate import exceptions
15 from migrate.versioning.util.keyedinstance import KeyedInstance
15 from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance
16 from migrate.versioning.util.importpath import import_path
16 from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path
17
17
18
18
19 log = logging.getLogger(__name__)
19 log = logging.getLogger(__name__)
20
20
21 def load_model(dotted_name):
21 def load_model(dotted_name):
22 """Import module and use module-level variable".
22 """Import module and use module-level variable".
23
23
24 :param dotted_name: path to model in form of string: ``some.python.module:Class``
24 :param dotted_name: path to model in form of string: ``some.python.module:Class``
25
25
26 .. versionchanged:: 0.5.4
26 .. versionchanged:: 0.5.4
27
27
28 """
28 """
29 if isinstance(dotted_name, basestring):
29 if isinstance(dotted_name, basestring):
30 if ':' not in dotted_name:
30 if ':' not in dotted_name:
31 # backwards compatibility
31 # backwards compatibility
32 warnings.warn('model should be in form of module.model:User '
32 warnings.warn('model should be in form of module.model:User '
33 'and not module.model.User', exceptions.MigrateDeprecationWarning)
33 'and not module.model.User', exceptions.MigrateDeprecationWarning)
34 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
34 dotted_name = ':'.join(dotted_name.rsplit('.', 1))
35 return EntryPoint.parse('x=%s' % dotted_name).load(False)
35 return EntryPoint.parse('x=%s' % dotted_name).load(False)
36 else:
36 else:
37 # Assume it's already loaded.
37 # Assume it's already loaded.
38 return dotted_name
38 return dotted_name
39
39
40 def asbool(obj):
40 def asbool(obj):
41 """Do everything to use object as bool"""
41 """Do everything to use object as bool"""
42 if isinstance(obj, basestring):
42 if isinstance(obj, basestring):
43 obj = obj.strip().lower()
43 obj = obj.strip().lower()
44 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
44 if obj in ['true', 'yes', 'on', 'y', 't', '1']:
45 return True
45 return True
46 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
46 elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
47 return False
47 return False
48 else:
48 else:
49 raise ValueError("String is not true/false: %r" % obj)
49 raise ValueError("String is not true/false: %r" % obj)
50 if obj in (True, False):
50 if obj in (True, False):
51 return bool(obj)
51 return bool(obj)
52 else:
52 else:
53 raise ValueError("String is not true/false: %r" % obj)
53 raise ValueError("String is not true/false: %r" % obj)
54
54
55 def guess_obj_type(obj):
55 def guess_obj_type(obj):
56 """Do everything to guess object type from string
56 """Do everything to guess object type from string
57
57
58 Tries to convert to `int`, `bool` and finally returns if not succeded.
58 Tries to convert to `int`, `bool` and finally returns if not succeded.
59
59
60 .. versionadded: 0.5.4
60 .. versionadded: 0.5.4
61 """
61 """
62
62
63 result = None
63 result = None
64
64
65 try:
65 try:
66 result = int(obj)
66 result = int(obj)
67 except:
67 except:
68 pass
68 pass
69
69
70 if result is None:
70 if result is None:
71 try:
71 try:
72 result = asbool(obj)
72 result = asbool(obj)
73 except:
73 except:
74 pass
74 pass
75
75
76 if result is not None:
76 if result is not None:
77 return result
77 return result
78 else:
78 else:
79 return obj
79 return obj
80
80
81 @decorator
81 @decorator
82 def catch_known_errors(f, *a, **kw):
82 def catch_known_errors(f, *a, **kw):
83 """Decorator that catches known api errors
83 """Decorator that catches known api errors
84
84
85 .. versionadded: 0.5.4
85 .. versionadded: 0.5.4
86 """
86 """
87
87
88 try:
88 try:
89 return f(*a, **kw)
89 return f(*a, **kw)
90 except exceptions.PathFoundError, e:
90 except exceptions.PathFoundError, e:
91 raise exceptions.KnownError("The path %s already exists" % e.args[0])
91 raise exceptions.KnownError("The path %s already exists" % e.args[0])
92
92
93 def construct_engine(engine, **opts):
93 def construct_engine(engine, **opts):
94 """.. versionadded:: 0.5.4
94 """.. versionadded:: 0.5.4
95
95
96 Constructs and returns SQLAlchemy engine.
96 Constructs and returns SQLAlchemy engine.
97
97
98 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
98 Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
99
99
100 :param engine: connection string or a existing engine
100 :param engine: connection string or a existing engine
101 :param engine_dict: python dictionary of options to pass to `create_engine`
101 :param engine_dict: python dictionary of options to pass to `create_engine`
102 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
102 :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
103 :type engine_dict: dict
103 :type engine_dict: dict
104 :type engine: string or Engine instance
104 :type engine: string or Engine instance
105 :type engine_arg_*: string
105 :type engine_arg_*: string
106 :returns: SQLAlchemy Engine
106 :returns: SQLAlchemy Engine
107
107
108 .. note::
108 .. note::
109
109
110 keyword parameters override ``engine_dict`` values.
110 keyword parameters override ``engine_dict`` values.
111
111
112 """
112 """
113 if isinstance(engine, Engine):
113 if isinstance(engine, Engine):
114 return engine
114 return engine
115 elif not isinstance(engine, basestring):
115 elif not isinstance(engine, basestring):
116 raise ValueError("you need to pass either an existing engine or a database uri")
116 raise ValueError("you need to pass either an existing engine or a database uri")
117
117
118 # get options for create_engine
118 # get options for create_engine
119 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
119 if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
120 kwargs = opts['engine_dict']
120 kwargs = opts['engine_dict']
121 else:
121 else:
122 kwargs = dict()
122 kwargs = dict()
123
123
124 # DEPRECATED: handle echo the old way
124 # DEPRECATED: handle echo the old way
125 echo = asbool(opts.get('echo', False))
125 echo = asbool(opts.get('echo', False))
126 if echo:
126 if echo:
127 warnings.warn('echo=True parameter is deprecated, pass '
127 warnings.warn('echo=True parameter is deprecated, pass '
128 'engine_arg_echo=True or engine_dict={"echo": True}',
128 'engine_arg_echo=True or engine_dict={"echo": True}',
129 exceptions.MigrateDeprecationWarning)
129 exceptions.MigrateDeprecationWarning)
130 kwargs['echo'] = echo
130 kwargs['echo'] = echo
131
131
132 # parse keyword arguments
132 # parse keyword arguments
133 for key, value in opts.iteritems():
133 for key, value in opts.iteritems():
134 if key.startswith('engine_arg_'):
134 if key.startswith('engine_arg_'):
135 kwargs[key[11:]] = guess_obj_type(value)
135 kwargs[key[11:]] = guess_obj_type(value)
136
136
137 log.debug('Constructing engine')
137 log.debug('Constructing engine')
138 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
138 # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
139 # seems like 0.5.x branch does not work with engine.dispose and staticpool
139 # seems like 0.5.x branch does not work with engine.dispose and staticpool
140 return create_engine(engine, **kwargs)
140 return create_engine(engine, **kwargs)
141
141
142 @decorator
142 @decorator
143 def with_engine(f, *a, **kw):
143 def with_engine(f, *a, **kw):
144 """Decorator for :mod:`migrate.versioning.api` functions
144 """Decorator for :mod:`migrate.versioning.api` functions
145 to safely close resources after function usage.
145 to safely close resources after function usage.
146
146
147 Passes engine parameters to :func:`construct_engine` and
147 Passes engine parameters to :func:`construct_engine` and
148 resulting parameter is available as kw['engine'].
148 resulting parameter is available as kw['engine'].
149
149
150 Engine is disposed after wrapped function is executed.
150 Engine is disposed after wrapped function is executed.
151
151
152 .. versionadded: 0.6.0
152 .. versionadded: 0.6.0
153 """
153 """
154 url = a[0]
154 url = a[0]
155 engine = construct_engine(url, **kw)
155 engine = construct_engine(url, **kw)
156
156
157 try:
157 try:
158 kw['engine'] = engine
158 kw['engine'] = engine
159 return f(*a, **kw)
159 return f(*a, **kw)
160 finally:
160 finally:
161 if isinstance(engine, Engine):
161 if isinstance(engine, Engine):
162 log.debug('Disposing SQLAlchemy engine %s', engine)
162 log.debug('Disposing SQLAlchemy engine %s', engine)
163 engine.dispose()
163 engine.dispose()
164
164
165
165
166 class Memoize:
166 class Memoize:
167 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
167 """Memoize(fn) - an instance which acts like fn but memoizes its arguments
168 Will only work on functions with non-mutable arguments
168 Will only work on functions with non-mutable arguments
169
169
170 ActiveState Code 52201
170 ActiveState Code 52201
171 """
171 """
172 def __init__(self, fn):
172 def __init__(self, fn):
173 self.fn = fn
173 self.fn = fn
174 self.memo = {}
174 self.memo = {}
175
175
176 def __call__(self, *args):
176 def __call__(self, *args):
177 if not self.memo.has_key(args):
177 if not self.memo.has_key(args):
178 self.memo[args] = self.fn(*args)
178 self.memo[args] = self.fn(*args)
179 return self.memo[args]
179 return self.memo[args]
@@ -1,215 +1,215 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3
3
4 import os
4 import os
5 import re
5 import re
6 import shutil
6 import shutil
7 import logging
7 import logging
8
8
9 from migrate import exceptions
9 from rhodecode.lib.dbmigrate.migrate import exceptions
10 from migrate.versioning import pathed, script
10 from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script
11
11
12
12
13 log = logging.getLogger(__name__)
13 log = logging.getLogger(__name__)
14
14
15 class VerNum(object):
15 class VerNum(object):
16 """A version number that behaves like a string and int at the same time"""
16 """A version number that behaves like a string and int at the same time"""
17
17
18 _instances = dict()
18 _instances = dict()
19
19
20 def __new__(cls, value):
20 def __new__(cls, value):
21 val = str(value)
21 val = str(value)
22 if val not in cls._instances:
22 if val not in cls._instances:
23 cls._instances[val] = super(VerNum, cls).__new__(cls)
23 cls._instances[val] = super(VerNum, cls).__new__(cls)
24 ret = cls._instances[val]
24 ret = cls._instances[val]
25 return ret
25 return ret
26
26
27 def __init__(self,value):
27 def __init__(self,value):
28 self.value = str(int(value))
28 self.value = str(int(value))
29 if self < 0:
29 if self < 0:
30 raise ValueError("Version number cannot be negative")
30 raise ValueError("Version number cannot be negative")
31
31
32 def __add__(self, value):
32 def __add__(self, value):
33 ret = int(self) + int(value)
33 ret = int(self) + int(value)
34 return VerNum(ret)
34 return VerNum(ret)
35
35
36 def __sub__(self, value):
36 def __sub__(self, value):
37 return self + (int(value) * -1)
37 return self + (int(value) * -1)
38
38
39 def __cmp__(self, value):
39 def __cmp__(self, value):
40 return int(self) - int(value)
40 return int(self) - int(value)
41
41
42 def __repr__(self):
42 def __repr__(self):
43 return "<VerNum(%s)>" % self.value
43 return "<VerNum(%s)>" % self.value
44
44
45 def __str__(self):
45 def __str__(self):
46 return str(self.value)
46 return str(self.value)
47
47
48 def __int__(self):
48 def __int__(self):
49 return int(self.value)
49 return int(self.value)
50
50
51
51
52 class Collection(pathed.Pathed):
52 class Collection(pathed.Pathed):
53 """A collection of versioning scripts in a repository"""
53 """A collection of versioning scripts in a repository"""
54
54
55 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
55 FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
56
56
57 def __init__(self, path):
57 def __init__(self, path):
58 """Collect current version scripts in repository
58 """Collect current version scripts in repository
59 and store them in self.versions
59 and store them in self.versions
60 """
60 """
61 super(Collection, self).__init__(path)
61 super(Collection, self).__init__(path)
62
62
63 # Create temporary list of files, allowing skipped version numbers.
63 # Create temporary list of files, allowing skipped version numbers.
64 files = os.listdir(path)
64 files = os.listdir(path)
65 if '1' in files:
65 if '1' in files:
66 # deprecation
66 # deprecation
67 raise Exception('It looks like you have a repository in the old '
67 raise Exception('It looks like you have a repository in the old '
68 'format (with directories for each version). '
68 'format (with directories for each version). '
69 'Please convert repository before proceeding.')
69 'Please convert repository before proceeding.')
70
70
71 tempVersions = dict()
71 tempVersions = dict()
72 for filename in files:
72 for filename in files:
73 match = self.FILENAME_WITH_VERSION.match(filename)
73 match = self.FILENAME_WITH_VERSION.match(filename)
74 if match:
74 if match:
75 num = int(match.group(1))
75 num = int(match.group(1))
76 tempVersions.setdefault(num, []).append(filename)
76 tempVersions.setdefault(num, []).append(filename)
77 else:
77 else:
78 pass # Must be a helper file or something, let's ignore it.
78 pass # Must be a helper file or something, let's ignore it.
79
79
80 # Create the versions member where the keys
80 # Create the versions member where the keys
81 # are VerNum's and the values are Version's.
81 # are VerNum's and the values are Version's.
82 self.versions = dict()
82 self.versions = dict()
83 for num, files in tempVersions.items():
83 for num, files in tempVersions.items():
84 self.versions[VerNum(num)] = Version(num, path, files)
84 self.versions[VerNum(num)] = Version(num, path, files)
85
85
86 @property
86 @property
87 def latest(self):
87 def latest(self):
88 """:returns: Latest version in Collection"""
88 """:returns: Latest version in Collection"""
89 return max([VerNum(0)] + self.versions.keys())
89 return max([VerNum(0)] + self.versions.keys())
90
90
91 def create_new_python_version(self, description, **k):
91 def create_new_python_version(self, description, **k):
92 """Create Python files for new version"""
92 """Create Python files for new version"""
93 ver = self.latest + 1
93 ver = self.latest + 1
94 extra = str_to_filename(description)
94 extra = str_to_filename(description)
95
95
96 if extra:
96 if extra:
97 if extra == '_':
97 if extra == '_':
98 extra = ''
98 extra = ''
99 elif not extra.startswith('_'):
99 elif not extra.startswith('_'):
100 extra = '_%s' % extra
100 extra = '_%s' % extra
101
101
102 filename = '%03d%s.py' % (ver, extra)
102 filename = '%03d%s.py' % (ver, extra)
103 filepath = self._version_path(filename)
103 filepath = self._version_path(filename)
104
104
105 script.PythonScript.create(filepath, **k)
105 script.PythonScript.create(filepath, **k)
106 self.versions[ver] = Version(ver, self.path, [filename])
106 self.versions[ver] = Version(ver, self.path, [filename])
107
107
108 def create_new_sql_version(self, database, **k):
108 def create_new_sql_version(self, database, **k):
109 """Create SQL files for new version"""
109 """Create SQL files for new version"""
110 ver = self.latest + 1
110 ver = self.latest + 1
111 self.versions[ver] = Version(ver, self.path, [])
111 self.versions[ver] = Version(ver, self.path, [])
112
112
113 # Create new files.
113 # Create new files.
114 for op in ('upgrade', 'downgrade'):
114 for op in ('upgrade', 'downgrade'):
115 filename = '%03d_%s_%s.sql' % (ver, database, op)
115 filename = '%03d_%s_%s.sql' % (ver, database, op)
116 filepath = self._version_path(filename)
116 filepath = self._version_path(filename)
117 script.SqlScript.create(filepath, **k)
117 script.SqlScript.create(filepath, **k)
118 self.versions[ver].add_script(filepath)
118 self.versions[ver].add_script(filepath)
119
119
120 def version(self, vernum=None):
120 def version(self, vernum=None):
121 """Returns latest Version if vernum is not given.
121 """Returns latest Version if vernum is not given.
122 Otherwise, returns wanted version"""
122 Otherwise, returns wanted version"""
123 if vernum is None:
123 if vernum is None:
124 vernum = self.latest
124 vernum = self.latest
125 return self.versions[VerNum(vernum)]
125 return self.versions[VerNum(vernum)]
126
126
127 @classmethod
127 @classmethod
128 def clear(cls):
128 def clear(cls):
129 super(Collection, cls).clear()
129 super(Collection, cls).clear()
130
130
131 def _version_path(self, ver):
131 def _version_path(self, ver):
132 """Returns path of file in versions repository"""
132 """Returns path of file in versions repository"""
133 return os.path.join(self.path, str(ver))
133 return os.path.join(self.path, str(ver))
134
134
135
135
136 class Version(object):
136 class Version(object):
137 """A single version in a collection
137 """A single version in a collection
138 :param vernum: Version Number
138 :param vernum: Version Number
139 :param path: Path to script files
139 :param path: Path to script files
140 :param filelist: List of scripts
140 :param filelist: List of scripts
141 :type vernum: int, VerNum
141 :type vernum: int, VerNum
142 :type path: string
142 :type path: string
143 :type filelist: list
143 :type filelist: list
144 """
144 """
145
145
146 def __init__(self, vernum, path, filelist):
146 def __init__(self, vernum, path, filelist):
147 self.version = VerNum(vernum)
147 self.version = VerNum(vernum)
148
148
149 # Collect scripts in this folder
149 # Collect scripts in this folder
150 self.sql = dict()
150 self.sql = dict()
151 self.python = None
151 self.python = None
152
152
153 for script in filelist:
153 for script in filelist:
154 self.add_script(os.path.join(path, script))
154 self.add_script(os.path.join(path, script))
155
155
156 def script(self, database=None, operation=None):
156 def script(self, database=None, operation=None):
157 """Returns SQL or Python Script"""
157 """Returns SQL or Python Script"""
158 for db in (database, 'default'):
158 for db in (database, 'default'):
159 # Try to return a .sql script first
159 # Try to return a .sql script first
160 try:
160 try:
161 return self.sql[db][operation]
161 return self.sql[db][operation]
162 except KeyError:
162 except KeyError:
163 continue # No .sql script exists
163 continue # No .sql script exists
164
164
165 # TODO: maybe add force Python parameter?
165 # TODO: maybe add force Python parameter?
166 ret = self.python
166 ret = self.python
167
167
168 assert ret is not None, \
168 assert ret is not None, \
169 "There is no script for %d version" % self.version
169 "There is no script for %d version" % self.version
170 return ret
170 return ret
171
171
172 def add_script(self, path):
172 def add_script(self, path):
173 """Add script to Collection/Version"""
173 """Add script to Collection/Version"""
174 if path.endswith(Extensions.py):
174 if path.endswith(Extensions.py):
175 self._add_script_py(path)
175 self._add_script_py(path)
176 elif path.endswith(Extensions.sql):
176 elif path.endswith(Extensions.sql):
177 self._add_script_sql(path)
177 self._add_script_sql(path)
178
178
179 SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
179 SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
180
180
181 def _add_script_sql(self, path):
181 def _add_script_sql(self, path):
182 basename = os.path.basename(path)
182 basename = os.path.basename(path)
183 match = self.SQL_FILENAME.match(basename)
183 match = self.SQL_FILENAME.match(basename)
184
184
185 if match:
185 if match:
186 version, dbms, op = match.group(1), match.group(2), match.group(3)
186 version, dbms, op = match.group(1), match.group(2), match.group(3)
187 else:
187 else:
188 raise exceptions.ScriptError(
188 raise exceptions.ScriptError(
189 "Invalid SQL script name %s " % basename + \
189 "Invalid SQL script name %s " % basename + \
190 "(needs to be ###_database_operation.sql)")
190 "(needs to be ###_database_operation.sql)")
191
191
192 # File the script into a dictionary
192 # File the script into a dictionary
193 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
193 self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
194
194
195 def _add_script_py(self, path):
195 def _add_script_py(self, path):
196 if self.python is not None:
196 if self.python is not None:
197 raise exceptions.ScriptError('You can only have one Python script '
197 raise exceptions.ScriptError('You can only have one Python script '
198 'per version, but you have: %s and %s' % (self.python, path))
198 'per version, but you have: %s and %s' % (self.python, path))
199 self.python = script.PythonScript(path)
199 self.python = script.PythonScript(path)
200
200
201
201
202 class Extensions:
202 class Extensions:
203 """A namespace for file extensions"""
203 """A namespace for file extensions"""
204 py = 'py'
204 py = 'py'
205 sql = 'sql'
205 sql = 'sql'
206
206
207 def str_to_filename(s):
207 def str_to_filename(s):
208 """Replaces spaces, (double and single) quotes
208 """Replaces spaces, (double and single) quotes
209 and double underscores to underscores
209 and double underscores to underscores
210 """
210 """
211
211
212 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
212 s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
213 while '__' in s:
213 while '__' in s:
214 s = s.replace('__', '_')
214 s = s.replace('__', '_')
215 return s
215 return s
General Comments 0
You need to be logged in to leave comments. Login now