Show More
@@ -0,0 +1,20 b'' | |||
|
1 | [db_settings] | |
|
2 | # Used to identify which repository this database is versioned under. | |
|
3 | # You can use the name of your project. | |
|
4 | repository_id=rhodecode_db_migrations | |
|
5 | ||
|
6 | # The name of the database table used to track the schema version. | |
|
7 | # This name shouldn't already be used by your project. | |
|
8 | # If this is changed once a database is under version control, you'll need to | |
|
9 | # change the table name in each database too. | |
|
10 | version_table=db_migrate_version | |
|
11 | ||
|
12 | # When committing a change script, Migrate will attempt to generate the | |
|
13 | # sql for all supported databases; normally, if one of them fails - probably | |
|
14 | # because you don't have that database installed - it is ignored and the | |
|
15 | # commit continues, perhaps ending successfully. | |
|
16 | # Databases in this list MUST compile successfully during a commit, or the | |
|
17 | # entire commit will fail. List the databases your application will actually | |
|
18 | # be using to ensure your updates to that database work properly. | |
|
19 | # This must be a list; example: ['postgres','sqlite'] | |
|
20 | required_dbs=['sqlite'] |
@@ -1,328 +1,332 b'' | |||
|
1 | #!/usr/bin/env python | |
|
2 | # encoding: utf-8 | |
|
3 | # database management for RhodeCode | |
|
4 | # Copyright (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com> | |
|
5 | # | |
|
1 | # -*- coding: utf-8 -*- | |
|
2 | """ | |
|
3 | rhodecode.lib.db_manage | |
|
4 | ~~~~~~~~~~~~~~~~~~~~~~~ | |
|
5 | ||
|
6 | Database creation, and setup module for RhodeCode | |
|
7 | ||
|
8 | :created_on: Apr 10, 2010 | |
|
9 | :author: marcink | |
|
10 | :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com> | |
|
11 | :license: GPLv3, see COPYING for more details. | |
|
12 | """ | |
|
6 | 13 | # This program is free software; you can redistribute it and/or |
|
7 | 14 | # modify it under the terms of the GNU General Public License |
|
8 | 15 | # as published by the Free Software Foundation; version 2 |
|
9 | 16 | # of the License or (at your opinion) any later version of the license. |
|
10 | 17 | # |
|
11 | 18 | # This program is distributed in the hope that it will be useful, |
|
12 | 19 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
13 | 20 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
14 | 21 | # GNU General Public License for more details. |
|
15 | 22 | # |
|
16 | 23 | # You should have received a copy of the GNU General Public License |
|
17 | 24 | # along with this program; if not, write to the Free Software |
|
18 | 25 | # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, |
|
19 | 26 | # MA 02110-1301, USA. |
|
20 | 27 | |
|
21 | """ | |
|
22 | Created on April 10, 2010 | |
|
23 | database management and creation for RhodeCode | |
|
24 | @author: marcink | |
|
25 | """ | |
|
26 | ||
|
27 | from os.path import dirname as dn, join as jn | |
|
28 | 28 | import os |
|
29 | 29 | import sys |
|
30 | 30 | import uuid |
|
31 | import logging | |
|
32 | from os.path import dirname as dn, join as jn | |
|
33 | ||
|
34 | from rhodecode import __dbversion__ | |
|
35 | from rhodecode.model.db import | |
|
36 | from rhodecode.model import meta | |
|
31 | 37 | |
|
32 | 38 | from rhodecode.lib.auth import get_crypt_password |
|
33 | 39 | from rhodecode.lib.utils import ask_ok |
|
34 | 40 | from rhodecode.model import init_model |
|
35 | 41 | from rhodecode.model.db import User, Permission, RhodeCodeUi, RhodeCodeSettings, \ |
|
36 | UserToPerm | |
|
37 | from rhodecode.model import meta | |
|
42 | UserToPerm, DbMigrateVersion | |
|
43 | ||
|
38 | 44 | from sqlalchemy.engine import create_engine |
|
39 | import logging | |
|
45 | ||
|
40 | 46 | |
|
41 | 47 | log = logging.getLogger(__name__) |
|
42 | 48 | |
|
43 | 49 | class DbManage(object): |
|
44 | 50 | def __init__(self, log_sql, dbconf, root, tests=False): |
|
45 | 51 | self.dbname = dbconf.split('/')[-1] |
|
46 | 52 | self.tests = tests |
|
47 | 53 | self.root = root |
|
48 | 54 | self.dburi = dbconf |
|
49 | 55 | engine = create_engine(self.dburi, echo=log_sql) |
|
50 | 56 | init_model(engine) |
|
51 | 57 | self.sa = meta.Session() |
|
52 | 58 | self.db_exists = False |
|
53 | 59 | |
|
54 | 60 | def check_for_db(self, override): |
|
55 | 61 | db_path = jn(self.root, self.dbname) |
|
56 | 62 | if self.dburi.startswith('sqlite'): |
|
57 | 63 | log.info('checking for existing db in %s', db_path) |
|
58 | 64 | if os.path.isfile(db_path): |
|
59 | 65 | |
|
60 | 66 | self.db_exists = True |
|
61 | 67 | if not override: |
|
62 | 68 | raise Exception('database already exists') |
|
63 | 69 | |
|
64 | 70 | def create_tables(self, override=False): |
|
65 | 71 | """ |
|
66 | 72 | Create a auth database |
|
67 | 73 | """ |
|
68 | 74 | self.check_for_db(override) |
|
69 | 75 | if self.db_exists: |
|
70 | 76 | log.info("database exist and it's going to be destroyed") |
|
71 | 77 | if self.tests: |
|
72 | 78 | destroy = True |
|
73 | 79 | else: |
|
74 | 80 | destroy = ask_ok('Are you sure to destroy old database ? [y/n]') |
|
75 | 81 | if not destroy: |
|
76 | 82 | sys.exit() |
|
77 | 83 | if self.db_exists and destroy: |
|
78 | 84 | os.remove(jn(self.root, self.dbname)) |
|
79 | 85 | checkfirst = not override |
|
80 | 86 | meta.Base.metadata.create_all(checkfirst=checkfirst) |
|
81 | 87 | log.info('Created tables for %s', self.dbname) |
|
82 | 88 | |
|
83 | 89 | |
|
84 | 90 | |
|
85 | 91 | def set_db_version(self): |
|
86 | from rhodecode import __dbversion__ | |
|
87 | from rhodecode.model.db import DbMigrateVersion | |
|
88 | 92 | try: |
|
89 | 93 | ver = DbMigrateVersion() |
|
90 | 94 | ver.version = __dbversion__ |
|
91 | 95 | ver.repository_id = 'rhodecode_db_migrations' |
|
92 | 96 | ver.repository_path = 'versions' |
|
93 | 97 | self.sa.add(ver) |
|
94 | 98 | self.sa.commit() |
|
95 | 99 | except: |
|
96 | 100 | self.sa.rollback() |
|
97 | 101 | raise |
|
98 | 102 | log.info('db version set to: %s', __dbversion__) |
|
99 | 103 | |
|
100 | 104 | def admin_prompt(self, second=False): |
|
101 | 105 | if not self.tests: |
|
102 | 106 | import getpass |
|
103 | 107 | |
|
104 | 108 | |
|
105 | 109 | def get_password(): |
|
106 | 110 | password = getpass.getpass('Specify admin password (min 6 chars):') |
|
107 | 111 | confirm = getpass.getpass('Confirm password:') |
|
108 | 112 | |
|
109 | 113 | if password != confirm: |
|
110 | 114 | log.error('passwords mismatch') |
|
111 | 115 | return False |
|
112 | 116 | if len(password) < 6: |
|
113 | 117 | log.error('password is to short use at least 6 characters') |
|
114 | 118 | return False |
|
115 | 119 | |
|
116 | 120 | return password |
|
117 | 121 | |
|
118 | 122 | username = raw_input('Specify admin username:') |
|
119 | 123 | |
|
120 | 124 | password = get_password() |
|
121 | 125 | if not password: |
|
122 | 126 | #second try |
|
123 | 127 | password = get_password() |
|
124 | 128 | if not password: |
|
125 | 129 | sys.exit() |
|
126 | 130 | |
|
127 | 131 | email = raw_input('Specify admin email:') |
|
128 | 132 | self.create_user(username, password, email, True) |
|
129 | 133 | else: |
|
130 | 134 | log.info('creating admin and regular test users') |
|
131 | 135 | self.create_user('test_admin', 'test12', 'test_admin@mail.com', True) |
|
132 | 136 | self.create_user('test_regular', 'test12', 'test_regular@mail.com', False) |
|
133 | 137 | self.create_user('test_regular2', 'test12', 'test_regular2@mail.com', False) |
|
134 | 138 | |
|
135 | 139 | |
|
136 | 140 | |
|
137 | 141 | def config_prompt(self, test_repo_path=''): |
|
138 | 142 | log.info('Setting up repositories config') |
|
139 | 143 | |
|
140 | 144 | if not self.tests and not test_repo_path: |
|
141 | 145 | path = raw_input('Specify valid full path to your repositories' |
|
142 | 146 | ' you can change this later in application settings:') |
|
143 | 147 | else: |
|
144 | 148 | path = test_repo_path |
|
145 | 149 | |
|
146 | 150 | if not os.path.isdir(path): |
|
147 | 151 | log.error('You entered wrong path: %s', path) |
|
148 | 152 | sys.exit() |
|
149 | 153 | |
|
150 | 154 | hooks1 = RhodeCodeUi() |
|
151 | 155 | hooks1.ui_section = 'hooks' |
|
152 | 156 | hooks1.ui_key = 'changegroup.update' |
|
153 | 157 | hooks1.ui_value = 'hg update >&2' |
|
154 | 158 | hooks1.ui_active = False |
|
155 | 159 | |
|
156 | 160 | hooks2 = RhodeCodeUi() |
|
157 | 161 | hooks2.ui_section = 'hooks' |
|
158 | 162 | hooks2.ui_key = 'changegroup.repo_size' |
|
159 | 163 | hooks2.ui_value = 'python:rhodecode.lib.hooks.repo_size' |
|
160 | 164 | |
|
161 | 165 | hooks3 = RhodeCodeUi() |
|
162 | 166 | hooks3.ui_section = 'hooks' |
|
163 | 167 | hooks3.ui_key = 'pretxnchangegroup.push_logger' |
|
164 | 168 | hooks3.ui_value = 'python:rhodecode.lib.hooks.log_push_action' |
|
165 | 169 | |
|
166 | 170 | hooks4 = RhodeCodeUi() |
|
167 | 171 | hooks4.ui_section = 'hooks' |
|
168 | 172 | hooks4.ui_key = 'preoutgoing.pull_logger' |
|
169 | 173 | hooks4.ui_value = 'python:rhodecode.lib.hooks.log_pull_action' |
|
170 | 174 | |
|
171 | 175 | #for mercurial 1.7 set backward comapatibility with format |
|
172 | 176 | |
|
173 | 177 | dotencode_disable = RhodeCodeUi() |
|
174 | 178 | dotencode_disable.ui_section = 'format' |
|
175 | 179 | dotencode_disable.ui_key = 'dotencode' |
|
176 | 180 | dotencode_disable.ui_section = 'false' |
|
177 | 181 | |
|
178 | 182 | |
|
179 | 183 | web1 = RhodeCodeUi() |
|
180 | 184 | web1.ui_section = 'web' |
|
181 | 185 | web1.ui_key = 'push_ssl' |
|
182 | 186 | web1.ui_value = 'false' |
|
183 | 187 | |
|
184 | 188 | web2 = RhodeCodeUi() |
|
185 | 189 | web2.ui_section = 'web' |
|
186 | 190 | web2.ui_key = 'allow_archive' |
|
187 | 191 | web2.ui_value = 'gz zip bz2' |
|
188 | 192 | |
|
189 | 193 | web3 = RhodeCodeUi() |
|
190 | 194 | web3.ui_section = 'web' |
|
191 | 195 | web3.ui_key = 'allow_push' |
|
192 | 196 | web3.ui_value = '*' |
|
193 | 197 | |
|
194 | 198 | web4 = RhodeCodeUi() |
|
195 | 199 | web4.ui_section = 'web' |
|
196 | 200 | web4.ui_key = 'baseurl' |
|
197 | 201 | web4.ui_value = '/' |
|
198 | 202 | |
|
199 | 203 | paths = RhodeCodeUi() |
|
200 | 204 | paths.ui_section = 'paths' |
|
201 | 205 | paths.ui_key = '/' |
|
202 | 206 | paths.ui_value = path |
|
203 | 207 | |
|
204 | 208 | |
|
205 | 209 | hgsettings1 = RhodeCodeSettings('realm', 'RhodeCode authentication') |
|
206 | 210 | hgsettings2 = RhodeCodeSettings('title', 'RhodeCode') |
|
207 | 211 | |
|
208 | 212 | |
|
209 | 213 | try: |
|
210 | 214 | self.sa.add(hooks1) |
|
211 | 215 | self.sa.add(hooks2) |
|
212 | 216 | self.sa.add(hooks3) |
|
213 | 217 | self.sa.add(hooks4) |
|
214 | 218 | self.sa.add(web1) |
|
215 | 219 | self.sa.add(web2) |
|
216 | 220 | self.sa.add(web3) |
|
217 | 221 | self.sa.add(web4) |
|
218 | 222 | self.sa.add(paths) |
|
219 | 223 | self.sa.add(hgsettings1) |
|
220 | 224 | self.sa.add(hgsettings2) |
|
221 | 225 | self.sa.add(dotencode_disable) |
|
222 | 226 | for k in ['ldap_active', 'ldap_host', 'ldap_port', 'ldap_ldaps', |
|
223 | 227 | 'ldap_dn_user', 'ldap_dn_pass', 'ldap_base_dn']: |
|
224 | 228 | |
|
225 | 229 | setting = RhodeCodeSettings(k, '') |
|
226 | 230 | self.sa.add(setting) |
|
227 | 231 | |
|
228 | 232 | self.sa.commit() |
|
229 | 233 | except: |
|
230 | 234 | self.sa.rollback() |
|
231 | 235 | raise |
|
232 | 236 | log.info('created ui config') |
|
233 | 237 | |
|
234 | 238 | def create_user(self, username, password, email='', admin=False): |
|
235 | 239 | log.info('creating administrator user %s', username) |
|
236 | 240 | new_user = User() |
|
237 | 241 | new_user.username = username |
|
238 | 242 | new_user.password = get_crypt_password(password) |
|
239 | 243 | new_user.name = 'RhodeCode' |
|
240 | 244 | new_user.lastname = 'Admin' |
|
241 | 245 | new_user.email = email |
|
242 | 246 | new_user.admin = admin |
|
243 | 247 | new_user.active = True |
|
244 | 248 | |
|
245 | 249 | try: |
|
246 | 250 | self.sa.add(new_user) |
|
247 | 251 | self.sa.commit() |
|
248 | 252 | except: |
|
249 | 253 | self.sa.rollback() |
|
250 | 254 | raise |
|
251 | 255 | |
|
252 | 256 | def create_default_user(self): |
|
253 | 257 | log.info('creating default user') |
|
254 | 258 | #create default user for handling default permissions. |
|
255 | 259 | def_user = User() |
|
256 | 260 | def_user.username = 'default' |
|
257 | 261 | def_user.password = get_crypt_password(str(uuid.uuid1())[:8]) |
|
258 | 262 | def_user.name = 'Anonymous' |
|
259 | 263 | def_user.lastname = 'User' |
|
260 | 264 | def_user.email = 'anonymous@rhodecode.org' |
|
261 | 265 | def_user.admin = False |
|
262 | 266 | def_user.active = False |
|
263 | 267 | try: |
|
264 | 268 | self.sa.add(def_user) |
|
265 | 269 | self.sa.commit() |
|
266 | 270 | except: |
|
267 | 271 | self.sa.rollback() |
|
268 | 272 | raise |
|
269 | 273 | |
|
270 | 274 | def create_permissions(self): |
|
271 | 275 | #module.(access|create|change|delete)_[name] |
|
272 | 276 | #module.(read|write|owner) |
|
273 | 277 | perms = [('repository.none', 'Repository no access'), |
|
274 | 278 | ('repository.read', 'Repository read access'), |
|
275 | 279 | ('repository.write', 'Repository write access'), |
|
276 | 280 | ('repository.admin', 'Repository admin access'), |
|
277 | 281 | ('hg.admin', 'Hg Administrator'), |
|
278 | 282 | ('hg.create.repository', 'Repository create'), |
|
279 | 283 | ('hg.create.none', 'Repository creation disabled'), |
|
280 | 284 | ('hg.register.none', 'Register disabled'), |
|
281 | 285 | ('hg.register.manual_activate', 'Register new user with rhodecode without manual activation'), |
|
282 | 286 | ('hg.register.auto_activate', 'Register new user with rhodecode without auto activation'), |
|
283 | 287 | ] |
|
284 | 288 | |
|
285 | 289 | for p in perms: |
|
286 | 290 | new_perm = Permission() |
|
287 | 291 | new_perm.permission_name = p[0] |
|
288 | 292 | new_perm.permission_longname = p[1] |
|
289 | 293 | try: |
|
290 | 294 | self.sa.add(new_perm) |
|
291 | 295 | self.sa.commit() |
|
292 | 296 | except: |
|
293 | 297 | self.sa.rollback() |
|
294 | 298 | raise |
|
295 | 299 | |
|
296 | 300 | def populate_default_permissions(self): |
|
297 | 301 | log.info('creating default user permissions') |
|
298 | 302 | |
|
299 | 303 | default_user = self.sa.query(User)\ |
|
300 | 304 | .filter(User.username == 'default').scalar() |
|
301 | 305 | |
|
302 | 306 | reg_perm = UserToPerm() |
|
303 | 307 | reg_perm.user = default_user |
|
304 | 308 | reg_perm.permission = self.sa.query(Permission)\ |
|
305 | 309 | .filter(Permission.permission_name == 'hg.register.manual_activate')\ |
|
306 | 310 | .scalar() |
|
307 | 311 | |
|
308 | 312 | create_repo_perm = UserToPerm() |
|
309 | 313 | create_repo_perm.user = default_user |
|
310 | 314 | create_repo_perm.permission = self.sa.query(Permission)\ |
|
311 | 315 | .filter(Permission.permission_name == 'hg.create.repository')\ |
|
312 | 316 | .scalar() |
|
313 | 317 | |
|
314 | 318 | default_repo_perm = UserToPerm() |
|
315 | 319 | default_repo_perm.user = default_user |
|
316 | 320 | default_repo_perm.permission = self.sa.query(Permission)\ |
|
317 | 321 | .filter(Permission.permission_name == 'repository.read')\ |
|
318 | 322 | .scalar() |
|
319 | 323 | |
|
320 | 324 | try: |
|
321 | 325 | self.sa.add(reg_perm) |
|
322 | 326 | self.sa.add(create_repo_perm) |
|
323 | 327 | self.sa.add(default_repo_perm) |
|
324 | 328 | self.sa.commit() |
|
325 | 329 | except: |
|
326 | 330 | self.sa.rollback() |
|
327 | 331 | raise |
|
328 | 332 |
@@ -1,59 +1,77 b'' | |||
|
1 | 1 | # -*- coding: utf-8 -*- |
|
2 | 2 | """ |
|
3 | 3 | rhodecode.lib.dbmigrate.__init__ |
|
4 | 4 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
|
5 | 5 | |
|
6 | 6 | Database migration modules |
|
7 | 7 | |
|
8 | 8 | :created_on: Dec 11, 2010 |
|
9 | 9 | :author: marcink |
|
10 | 10 | :copyright: (C) 2009-2010 Marcin Kuzminski <marcin@python-works.com> |
|
11 | 11 | :license: GPLv3, see COPYING for more details. |
|
12 | 12 | """ |
|
13 | 13 | # This program is free software; you can redistribute it and/or |
|
14 | 14 | # modify it under the terms of the GNU General Public License |
|
15 | 15 | # as published by the Free Software Foundation; version 2 |
|
16 | 16 | # of the License or (at your opinion) any later version of the license. |
|
17 | 17 | # |
|
18 | 18 | # This program is distributed in the hope that it will be useful, |
|
19 | 19 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
20 | 20 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
21 | 21 | # GNU General Public License for more details. |
|
22 | 22 | # |
|
23 | 23 | # You should have received a copy of the GNU General Public License |
|
24 | 24 | # along with this program; if not, write to the Free Software |
|
25 | 25 | # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, |
|
26 | 26 | # MA 02110-1301, USA. |
|
27 | 27 | |
|
28 | from rhodecode.lib.utils import BasePasterCommand | |
|
28 | import logging | |
|
29 | from sqlalchemy import engine_from_config | |
|
30 | ||
|
31 | from rhodecode.lib.dbmigrate.migrate.exceptions import \ | |
|
32 | DatabaseNotControlledError | |
|
29 | 33 | from rhodecode.lib.utils import BasePasterCommand, Command, add_cache |
|
30 | 34 | |
|
31 | from sqlalchemy import engine_from_config | |
|
35 | log = logging.getLogger(__name__) | |
|
32 | 36 | |
|
33 | 37 | class UpgradeDb(BasePasterCommand): |
|
34 | 38 | """Command used for paster to upgrade our database to newer version |
|
35 | 39 | """ |
|
36 | 40 | |
|
37 | 41 | max_args = 1 |
|
38 | 42 | min_args = 1 |
|
39 | 43 | |
|
40 | 44 | usage = "CONFIG_FILE" |
|
41 | 45 | summary = "Upgrades current db to newer version given configuration file" |
|
42 | 46 | group_name = "RhodeCode" |
|
43 | 47 | |
|
44 | 48 | parser = Command.standard_parser(verbose=True) |
|
45 | 49 | |
|
46 | 50 | def command(self): |
|
47 | 51 | from pylons import config |
|
48 | 52 | add_cache(config) |
|
49 | engine = engine_from_config(config, 'sqlalchemy.db1.') | |
|
50 |
|
|
|
51 | raise NotImplementedError('Not implemented yet') | |
|
53 | #engine = engine_from_config(config, 'sqlalchemy.db1.') | |
|
54 | #rint engine | |
|
55 | ||
|
56 | from rhodecode.lib.dbmigrate.migrate.versioning import api | |
|
57 | path = 'rhodecode/lib/dbmigrate' | |
|
58 | ||
|
52 | 59 | |
|
60 | try: | |
|
61 | curr_version = api.db_version(config['sqlalchemy.db1.url'], path) | |
|
62 | msg = ('Found current database under version' | |
|
63 | ' control with version %s' % curr_version) | |
|
64 | ||
|
65 | except (RuntimeError, DatabaseNotControlledError), e: | |
|
66 | curr_version = 0 | |
|
67 | msg = ('Current database is not under version control setting' | |
|
68 | ' as version %s' % curr_version) | |
|
69 | ||
|
70 | print msg | |
|
53 | 71 | |
|
54 | 72 | def update_parser(self): |
|
55 | 73 | self.parser.add_option('--sql', |
|
56 | 74 | action='store_true', |
|
57 | 75 | dest='just_sql', |
|
58 | 76 | help="Prints upgrade sql for further investigation", |
|
59 | 77 | default=False) |
@@ -1,9 +1,9 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for |
|
3 | 3 | database schema version and repository management and |
|
4 | 4 | :mod:`migrate.changeset` that allows to define database schema changes |
|
5 | 5 | using Python. |
|
6 | 6 | """ |
|
7 | 7 | |
|
8 | from migrate.versioning import * | |
|
9 | from migrate.changeset import * | |
|
8 | from rhodecode.lib.dbmigrate.migrate.versioning import * | |
|
9 | from rhodecode.lib.dbmigrate.migrate.changeset import * |
@@ -1,28 +1,28 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | This module extends SQLAlchemy and provides additional DDL [#]_ |
|
3 | 3 | support. |
|
4 | 4 | |
|
5 | 5 | .. [#] SQL Data Definition Language |
|
6 | 6 | """ |
|
7 | 7 | import re |
|
8 | 8 | import warnings |
|
9 | 9 | |
|
10 | 10 | import sqlalchemy |
|
11 | 11 | from sqlalchemy import __version__ as _sa_version |
|
12 | 12 | |
|
13 | 13 | warnings.simplefilter('always', DeprecationWarning) |
|
14 | 14 | |
|
15 | 15 | _sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split(".")) |
|
16 | 16 | SQLA_06 = _sa_version >= (0, 6) |
|
17 | 17 | |
|
18 | 18 | del re |
|
19 | 19 | del _sa_version |
|
20 | 20 | |
|
21 | from migrate.changeset.schema import * | |
|
22 | from migrate.changeset.constraint import * | |
|
21 | from rhodecode.lib.dbmigrate.migrate.changeset.schema import * | |
|
22 | from rhodecode.lib.dbmigrate.migrate.changeset.constraint import * | |
|
23 | 23 | |
|
24 | 24 | sqlalchemy.schema.Table.__bases__ += (ChangesetTable, ) |
|
25 | 25 | sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, ) |
|
26 | 26 | sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, ) |
|
27 | 27 | |
|
28 | 28 | sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, ) |
@@ -1,358 +1,358 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Extensions to SQLAlchemy for altering existing tables. |
|
3 | 3 | |
|
4 | 4 | At the moment, this isn't so much based off of ANSI as much as |
|
5 | 5 | things that just happen to work with multiple databases. |
|
6 | 6 | """ |
|
7 | 7 | import StringIO |
|
8 | 8 | |
|
9 | 9 | import sqlalchemy as sa |
|
10 | 10 | from sqlalchemy.schema import SchemaVisitor |
|
11 | 11 | from sqlalchemy.engine.default import DefaultDialect |
|
12 | 12 | from sqlalchemy.sql import ClauseElement |
|
13 | 13 | from sqlalchemy.schema import (ForeignKeyConstraint, |
|
14 | 14 | PrimaryKeyConstraint, |
|
15 | 15 | CheckConstraint, |
|
16 | 16 | UniqueConstraint, |
|
17 | 17 | Index) |
|
18 | 18 | |
|
19 | from migrate import exceptions | |
|
20 | from migrate.changeset import constraint, SQLA_06 | |
|
19 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
20 | from rhodecode.lib.dbmigrate.migrate.changeset import constraint, SQLA_06 | |
|
21 | 21 | |
|
22 | 22 | if not SQLA_06: |
|
23 | 23 | from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper |
|
24 | 24 | else: |
|
25 | 25 | from sqlalchemy.schema import AddConstraint, DropConstraint |
|
26 | 26 | from sqlalchemy.sql.compiler import DDLCompiler |
|
27 | 27 | SchemaGenerator = SchemaDropper = DDLCompiler |
|
28 | 28 | |
|
29 | 29 | |
|
30 | 30 | class AlterTableVisitor(SchemaVisitor): |
|
31 | 31 | """Common operations for ``ALTER TABLE`` statements.""" |
|
32 | 32 | |
|
33 | 33 | if SQLA_06: |
|
34 | 34 | # engine.Compiler looks for .statement |
|
35 | 35 | # when it spawns off a new compiler |
|
36 | 36 | statement = ClauseElement() |
|
37 | 37 | |
|
38 | 38 | def append(self, s): |
|
39 | 39 | """Append content to the SchemaIterator's query buffer.""" |
|
40 | 40 | |
|
41 | 41 | self.buffer.write(s) |
|
42 | 42 | |
|
43 | 43 | def execute(self): |
|
44 | 44 | """Execute the contents of the SchemaIterator's buffer.""" |
|
45 | 45 | try: |
|
46 | 46 | return self.connection.execute(self.buffer.getvalue()) |
|
47 | 47 | finally: |
|
48 | 48 | self.buffer.truncate(0) |
|
49 | 49 | |
|
50 | 50 | def __init__(self, dialect, connection, **kw): |
|
51 | 51 | self.connection = connection |
|
52 | 52 | self.buffer = StringIO.StringIO() |
|
53 | 53 | self.preparer = dialect.identifier_preparer |
|
54 | 54 | self.dialect = dialect |
|
55 | 55 | |
|
56 | 56 | def traverse_single(self, elem): |
|
57 | 57 | ret = super(AlterTableVisitor, self).traverse_single(elem) |
|
58 | 58 | if ret: |
|
59 | 59 | # adapt to 0.6 which uses a string-returning |
|
60 | 60 | # object |
|
61 | 61 | self.append(" %s" % ret) |
|
62 | 62 | |
|
63 | 63 | def _to_table(self, param): |
|
64 | 64 | """Returns the table object for the given param object.""" |
|
65 | 65 | if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)): |
|
66 | 66 | ret = param.table |
|
67 | 67 | else: |
|
68 | 68 | ret = param |
|
69 | 69 | return ret |
|
70 | 70 | |
|
71 | 71 | def start_alter_table(self, param): |
|
72 | 72 | """Returns the start of an ``ALTER TABLE`` SQL-Statement. |
|
73 | 73 | |
|
74 | 74 | Use the param object to determine the table name and use it |
|
75 | 75 | for building the SQL statement. |
|
76 | 76 | |
|
77 | 77 | :param param: object to determine the table from |
|
78 | 78 | :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`, |
|
79 | 79 | :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`, |
|
80 | 80 | or string (table name) |
|
81 | 81 | """ |
|
82 | 82 | table = self._to_table(param) |
|
83 | 83 | self.append('\nALTER TABLE %s ' % self.preparer.format_table(table)) |
|
84 | 84 | return table |
|
85 | 85 | |
|
86 | 86 | |
|
87 | 87 | class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator): |
|
88 | 88 | """Extends ansisql generator for column creation (alter table add col)""" |
|
89 | 89 | |
|
90 | 90 | def visit_column(self, column): |
|
91 | 91 | """Create a column (table already exists). |
|
92 | 92 | |
|
93 | 93 | :param column: column object |
|
94 | 94 | :type column: :class:`sqlalchemy.Column` instance |
|
95 | 95 | """ |
|
96 | 96 | if column.default is not None: |
|
97 | 97 | self.traverse_single(column.default) |
|
98 | 98 | |
|
99 | 99 | table = self.start_alter_table(column) |
|
100 | 100 | self.append("ADD ") |
|
101 | 101 | self.append(self.get_column_specification(column)) |
|
102 | 102 | |
|
103 | 103 | for cons in column.constraints: |
|
104 | 104 | self.traverse_single(cons) |
|
105 | 105 | self.execute() |
|
106 | 106 | |
|
107 | 107 | # ALTER TABLE STATEMENTS |
|
108 | 108 | |
|
109 | 109 | # add indexes and unique constraints |
|
110 | 110 | if column.index_name: |
|
111 | 111 | Index(column.index_name,column).create() |
|
112 | 112 | elif column.unique_name: |
|
113 | 113 | constraint.UniqueConstraint(column, |
|
114 | 114 | name=column.unique_name).create() |
|
115 | 115 | |
|
116 | 116 | # SA bounds FK constraints to table, add manually |
|
117 | 117 | for fk in column.foreign_keys: |
|
118 | 118 | self.add_foreignkey(fk.constraint) |
|
119 | 119 | |
|
120 | 120 | # add primary key constraint if needed |
|
121 | 121 | if column.primary_key_name: |
|
122 | 122 | cons = constraint.PrimaryKeyConstraint(column, |
|
123 | 123 | name=column.primary_key_name) |
|
124 | 124 | cons.create() |
|
125 | 125 | |
|
126 | 126 | if SQLA_06: |
|
127 | 127 | def add_foreignkey(self, fk): |
|
128 | 128 | self.connection.execute(AddConstraint(fk)) |
|
129 | 129 | |
|
130 | 130 | class ANSIColumnDropper(AlterTableVisitor, SchemaDropper): |
|
131 | 131 | """Extends ANSI SQL dropper for column dropping (``ALTER TABLE |
|
132 | 132 | DROP COLUMN``). |
|
133 | 133 | """ |
|
134 | 134 | |
|
135 | 135 | def visit_column(self, column): |
|
136 | 136 | """Drop a column from its table. |
|
137 | 137 | |
|
138 | 138 | :param column: the column object |
|
139 | 139 | :type column: :class:`sqlalchemy.Column` |
|
140 | 140 | """ |
|
141 | 141 | table = self.start_alter_table(column) |
|
142 | 142 | self.append('DROP COLUMN %s' % self.preparer.format_column(column)) |
|
143 | 143 | self.execute() |
|
144 | 144 | |
|
145 | 145 | |
|
146 | 146 | class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator): |
|
147 | 147 | """Manages changes to existing schema elements. |
|
148 | 148 | |
|
149 | 149 | Note that columns are schema elements; ``ALTER TABLE ADD COLUMN`` |
|
150 | 150 | is in SchemaGenerator. |
|
151 | 151 | |
|
152 | 152 | All items may be renamed. Columns can also have many of their properties - |
|
153 | 153 | type, for example - changed. |
|
154 | 154 | |
|
155 | 155 | Each function is passed a tuple, containing (object, name); where |
|
156 | 156 | object is a type of object you'd expect for that function |
|
157 | 157 | (ie. table for visit_table) and name is the object's new |
|
158 | 158 | name. NONE means the name is unchanged. |
|
159 | 159 | """ |
|
160 | 160 | |
|
161 | 161 | def visit_table(self, table): |
|
162 | 162 | """Rename a table. Other ops aren't supported.""" |
|
163 | 163 | self.start_alter_table(table) |
|
164 | 164 | self.append("RENAME TO %s" % self.preparer.quote(table.new_name, |
|
165 | 165 | table.quote)) |
|
166 | 166 | self.execute() |
|
167 | 167 | |
|
168 | 168 | def visit_index(self, index): |
|
169 | 169 | """Rename an index""" |
|
170 | 170 | if hasattr(self, '_validate_identifier'): |
|
171 | 171 | # SA <= 0.6.3 |
|
172 | 172 | self.append("ALTER INDEX %s RENAME TO %s" % ( |
|
173 | 173 | self.preparer.quote( |
|
174 | 174 | self._validate_identifier( |
|
175 | 175 | index.name, True), index.quote), |
|
176 | 176 | self.preparer.quote( |
|
177 | 177 | self._validate_identifier( |
|
178 | 178 | index.new_name, True), index.quote))) |
|
179 | 179 | else: |
|
180 | 180 | # SA >= 0.6.5 |
|
181 | 181 | self.append("ALTER INDEX %s RENAME TO %s" % ( |
|
182 | 182 | self.preparer.quote( |
|
183 | 183 | self._index_identifier( |
|
184 | 184 | index.name), index.quote), |
|
185 | 185 | self.preparer.quote( |
|
186 | 186 | self._index_identifier( |
|
187 | 187 | index.new_name), index.quote))) |
|
188 | 188 | self.execute() |
|
189 | 189 | |
|
190 | 190 | def visit_column(self, delta): |
|
191 | 191 | """Rename/change a column.""" |
|
192 | 192 | # ALTER COLUMN is implemented as several ALTER statements |
|
193 | 193 | keys = delta.keys() |
|
194 | 194 | if 'type' in keys: |
|
195 | 195 | self._run_subvisit(delta, self._visit_column_type) |
|
196 | 196 | if 'nullable' in keys: |
|
197 | 197 | self._run_subvisit(delta, self._visit_column_nullable) |
|
198 | 198 | if 'server_default' in keys: |
|
199 | 199 | # Skip 'default': only handle server-side defaults, others |
|
200 | 200 | # are managed by the app, not the db. |
|
201 | 201 | self._run_subvisit(delta, self._visit_column_default) |
|
202 | 202 | if 'name' in keys: |
|
203 | 203 | self._run_subvisit(delta, self._visit_column_name, start_alter=False) |
|
204 | 204 | |
|
205 | 205 | def _run_subvisit(self, delta, func, start_alter=True): |
|
206 | 206 | """Runs visit method based on what needs to be changed on column""" |
|
207 | 207 | table = self._to_table(delta.table) |
|
208 | 208 | col_name = delta.current_name |
|
209 | 209 | if start_alter: |
|
210 | 210 | self.start_alter_column(table, col_name) |
|
211 | 211 | ret = func(table, delta.result_column, delta) |
|
212 | 212 | self.execute() |
|
213 | 213 | |
|
214 | 214 | def start_alter_column(self, table, col_name): |
|
215 | 215 | """Starts ALTER COLUMN""" |
|
216 | 216 | self.start_alter_table(table) |
|
217 | 217 | self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote)) |
|
218 | 218 | |
|
219 | 219 | def _visit_column_nullable(self, table, column, delta): |
|
220 | 220 | nullable = delta['nullable'] |
|
221 | 221 | if nullable: |
|
222 | 222 | self.append("DROP NOT NULL") |
|
223 | 223 | else: |
|
224 | 224 | self.append("SET NOT NULL") |
|
225 | 225 | |
|
226 | 226 | def _visit_column_default(self, table, column, delta): |
|
227 | 227 | default_text = self.get_column_default_string(column) |
|
228 | 228 | if default_text is not None: |
|
229 | 229 | self.append("SET DEFAULT %s" % default_text) |
|
230 | 230 | else: |
|
231 | 231 | self.append("DROP DEFAULT") |
|
232 | 232 | |
|
233 | 233 | def _visit_column_type(self, table, column, delta): |
|
234 | 234 | type_ = delta['type'] |
|
235 | 235 | if SQLA_06: |
|
236 | 236 | type_text = str(type_.compile(dialect=self.dialect)) |
|
237 | 237 | else: |
|
238 | 238 | type_text = type_.dialect_impl(self.dialect).get_col_spec() |
|
239 | 239 | self.append("TYPE %s" % type_text) |
|
240 | 240 | |
|
241 | 241 | def _visit_column_name(self, table, column, delta): |
|
242 | 242 | self.start_alter_table(table) |
|
243 | 243 | col_name = self.preparer.quote(delta.current_name, table.quote) |
|
244 | 244 | new_name = self.preparer.format_column(delta.result_column) |
|
245 | 245 | self.append('RENAME COLUMN %s TO %s' % (col_name, new_name)) |
|
246 | 246 | |
|
247 | 247 | |
|
248 | 248 | class ANSIConstraintCommon(AlterTableVisitor): |
|
249 | 249 | """ |
|
250 | 250 | Migrate's constraints require a separate creation function from |
|
251 | 251 | SA's: Migrate's constraints are created independently of a table; |
|
252 | 252 | SA's are created at the same time as the table. |
|
253 | 253 | """ |
|
254 | 254 | |
|
255 | 255 | def get_constraint_name(self, cons): |
|
256 | 256 | """Gets a name for the given constraint. |
|
257 | 257 | |
|
258 | 258 | If the name is already set it will be used otherwise the |
|
259 | 259 | constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>` |
|
260 | 260 | method is used. |
|
261 | 261 | |
|
262 | 262 | :param cons: constraint object |
|
263 | 263 | """ |
|
264 | 264 | if cons.name is not None: |
|
265 | 265 | ret = cons.name |
|
266 | 266 | else: |
|
267 | 267 | ret = cons.name = cons.autoname() |
|
268 | 268 | return self.preparer.quote(ret, cons.quote) |
|
269 | 269 | |
|
270 | 270 | def visit_migrate_primary_key_constraint(self, *p, **k): |
|
271 | 271 | self._visit_constraint(*p, **k) |
|
272 | 272 | |
|
273 | 273 | def visit_migrate_foreign_key_constraint(self, *p, **k): |
|
274 | 274 | self._visit_constraint(*p, **k) |
|
275 | 275 | |
|
276 | 276 | def visit_migrate_check_constraint(self, *p, **k): |
|
277 | 277 | self._visit_constraint(*p, **k) |
|
278 | 278 | |
|
279 | 279 | def visit_migrate_unique_constraint(self, *p, **k): |
|
280 | 280 | self._visit_constraint(*p, **k) |
|
281 | 281 | |
|
282 | 282 | if SQLA_06: |
|
283 | 283 | class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): |
|
284 | 284 | def _visit_constraint(self, constraint): |
|
285 | 285 | constraint.name = self.get_constraint_name(constraint) |
|
286 | 286 | self.append(self.process(AddConstraint(constraint))) |
|
287 | 287 | self.execute() |
|
288 | 288 | |
|
289 | 289 | class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): |
|
290 | 290 | def _visit_constraint(self, constraint): |
|
291 | 291 | constraint.name = self.get_constraint_name(constraint) |
|
292 | 292 | self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade))) |
|
293 | 293 | self.execute() |
|
294 | 294 | |
|
295 | 295 | else: |
|
296 | 296 | class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): |
|
297 | 297 | |
|
298 | 298 | def get_constraint_specification(self, cons, **kwargs): |
|
299 | 299 | """Constaint SQL generators. |
|
300 | 300 | |
|
301 | 301 | We cannot use SA visitors because they append comma. |
|
302 | 302 | """ |
|
303 | 303 | |
|
304 | 304 | if isinstance(cons, PrimaryKeyConstraint): |
|
305 | 305 | if cons.name is not None: |
|
306 | 306 | self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons)) |
|
307 | 307 | self.append("PRIMARY KEY ") |
|
308 | 308 | self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) |
|
309 | 309 | for c in cons)) |
|
310 | 310 | self.define_constraint_deferrability(cons) |
|
311 | 311 | elif isinstance(cons, ForeignKeyConstraint): |
|
312 | 312 | self.define_foreign_key(cons) |
|
313 | 313 | elif isinstance(cons, CheckConstraint): |
|
314 | 314 | if cons.name is not None: |
|
315 | 315 | self.append("CONSTRAINT %s " % |
|
316 | 316 | self.preparer.format_constraint(cons)) |
|
317 | 317 | self.append("CHECK (%s)" % cons.sqltext) |
|
318 | 318 | self.define_constraint_deferrability(cons) |
|
319 | 319 | elif isinstance(cons, UniqueConstraint): |
|
320 | 320 | if cons.name is not None: |
|
321 | 321 | self.append("CONSTRAINT %s " % |
|
322 | 322 | self.preparer.format_constraint(cons)) |
|
323 | 323 | self.append("UNIQUE (%s)" % \ |
|
324 | 324 | (', '.join(self.preparer.quote(c.name, c.quote) for c in cons))) |
|
325 | 325 | self.define_constraint_deferrability(cons) |
|
326 | 326 | else: |
|
327 | 327 | raise exceptions.InvalidConstraintError(cons) |
|
328 | 328 | |
|
329 | 329 | def _visit_constraint(self, constraint): |
|
330 | 330 | |
|
331 | 331 | table = self.start_alter_table(constraint) |
|
332 | 332 | constraint.name = self.get_constraint_name(constraint) |
|
333 | 333 | self.append("ADD ") |
|
334 | 334 | self.get_constraint_specification(constraint) |
|
335 | 335 | self.execute() |
|
336 | 336 | |
|
337 | 337 | |
|
338 | 338 | class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): |
|
339 | 339 | |
|
340 | 340 | def _visit_constraint(self, constraint): |
|
341 | 341 | self.start_alter_table(constraint) |
|
342 | 342 | self.append("DROP CONSTRAINT ") |
|
343 | 343 | constraint.name = self.get_constraint_name(constraint) |
|
344 | 344 | self.append(self.preparer.format_constraint(constraint)) |
|
345 | 345 | if constraint.cascade: |
|
346 | 346 | self.cascade_constraint(constraint) |
|
347 | 347 | self.execute() |
|
348 | 348 | |
|
349 | 349 | def cascade_constraint(self, constraint): |
|
350 | 350 | self.append(" CASCADE") |
|
351 | 351 | |
|
352 | 352 | |
|
353 | 353 | class ANSIDialect(DefaultDialect): |
|
354 | 354 | columngenerator = ANSIColumnGenerator |
|
355 | 355 | columndropper = ANSIColumnDropper |
|
356 | 356 | schemachanger = ANSISchemaChanger |
|
357 | 357 | constraintgenerator = ANSIConstraintGenerator |
|
358 | 358 | constraintdropper = ANSIConstraintDropper |
@@ -1,202 +1,202 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | This module defines standalone schema constraint classes. |
|
3 | 3 | """ |
|
4 | 4 | from sqlalchemy import schema |
|
5 | 5 | |
|
6 | from migrate.exceptions import * | |
|
7 | from migrate.changeset import SQLA_06 | |
|
6 | from rhodecode.lib.dbmigrate.migrate.exceptions import * | |
|
7 | from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06 | |
|
8 | 8 | |
|
9 | 9 | class ConstraintChangeset(object): |
|
10 | 10 | """Base class for Constraint classes.""" |
|
11 | 11 | |
|
12 | 12 | def _normalize_columns(self, cols, table_name=False): |
|
13 | 13 | """Given: column objects or names; return col names and |
|
14 | 14 | (maybe) a table""" |
|
15 | 15 | colnames = [] |
|
16 | 16 | table = None |
|
17 | 17 | for col in cols: |
|
18 | 18 | if isinstance(col, schema.Column): |
|
19 | 19 | if col.table is not None and table is None: |
|
20 | 20 | table = col.table |
|
21 | 21 | if table_name: |
|
22 | 22 | col = '.'.join((col.table.name, col.name)) |
|
23 | 23 | else: |
|
24 | 24 | col = col.name |
|
25 | 25 | colnames.append(col) |
|
26 | 26 | return colnames, table |
|
27 | 27 | |
|
28 | 28 | def __do_imports(self, visitor_name, *a, **kw): |
|
29 | 29 | engine = kw.pop('engine', self.table.bind) |
|
30 | from migrate.changeset.databases.visitor import (get_engine_visitor, | |
|
30 | from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor, | |
|
31 | 31 | run_single_visitor) |
|
32 | 32 | visitorcallable = get_engine_visitor(engine, visitor_name) |
|
33 | 33 | run_single_visitor(engine, visitorcallable, self, *a, **kw) |
|
34 | 34 | |
|
35 | 35 | def create(self, *a, **kw): |
|
36 | 36 | """Create the constraint in the database. |
|
37 | 37 | |
|
38 | 38 | :param engine: the database engine to use. If this is \ |
|
39 | 39 | :keyword:`None` the instance's engine will be used |
|
40 | 40 | :type engine: :class:`sqlalchemy.engine.base.Engine` |
|
41 | 41 | :param connection: reuse connection istead of creating new one. |
|
42 | 42 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
43 | 43 | """ |
|
44 | 44 | # TODO: set the parent here instead of in __init__ |
|
45 | 45 | self.__do_imports('constraintgenerator', *a, **kw) |
|
46 | 46 | |
|
47 | 47 | def drop(self, *a, **kw): |
|
48 | 48 | """Drop the constraint from the database. |
|
49 | 49 | |
|
50 | 50 | :param engine: the database engine to use. If this is |
|
51 | 51 | :keyword:`None` the instance's engine will be used |
|
52 | 52 | :param cascade: Issue CASCADE drop if database supports it |
|
53 | 53 | :type engine: :class:`sqlalchemy.engine.base.Engine` |
|
54 | 54 | :type cascade: bool |
|
55 | 55 | :param connection: reuse connection istead of creating new one. |
|
56 | 56 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
57 | 57 | :returns: Instance with cleared columns |
|
58 | 58 | """ |
|
59 | 59 | self.cascade = kw.pop('cascade', False) |
|
60 | 60 | self.__do_imports('constraintdropper', *a, **kw) |
|
61 | 61 | # the spirit of Constraint objects is that they |
|
62 | 62 | # are immutable (just like in a DB. they're only ADDed |
|
63 | 63 | # or DROPped). |
|
64 | 64 | #self.columns.clear() |
|
65 | 65 | return self |
|
66 | 66 | |
|
67 | 67 | |
|
68 | 68 | class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint): |
|
69 | 69 | """Construct PrimaryKeyConstraint |
|
70 | 70 | |
|
71 | 71 | Migrate's additional parameters: |
|
72 | 72 | |
|
73 | 73 | :param cols: Columns in constraint. |
|
74 | 74 | :param table: If columns are passed as strings, this kw is required |
|
75 | 75 | :type table: Table instance |
|
76 | 76 | :type cols: strings or Column instances |
|
77 | 77 | """ |
|
78 | 78 | |
|
79 | 79 | __migrate_visit_name__ = 'migrate_primary_key_constraint' |
|
80 | 80 | |
|
81 | 81 | def __init__(self, *cols, **kwargs): |
|
82 | 82 | colnames, table = self._normalize_columns(cols) |
|
83 | 83 | table = kwargs.pop('table', table) |
|
84 | 84 | super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs) |
|
85 | 85 | if table is not None: |
|
86 | 86 | self._set_parent(table) |
|
87 | 87 | |
|
88 | 88 | |
|
89 | 89 | def autoname(self): |
|
90 | 90 | """Mimic the database's automatic constraint names""" |
|
91 | 91 | return "%s_pkey" % self.table.name |
|
92 | 92 | |
|
93 | 93 | |
|
94 | 94 | class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint): |
|
95 | 95 | """Construct ForeignKeyConstraint |
|
96 | 96 | |
|
97 | 97 | Migrate's additional parameters: |
|
98 | 98 | |
|
99 | 99 | :param columns: Columns in constraint |
|
100 | 100 | :param refcolumns: Columns that this FK reffers to in another table. |
|
101 | 101 | :param table: If columns are passed as strings, this kw is required |
|
102 | 102 | :type table: Table instance |
|
103 | 103 | :type columns: list of strings or Column instances |
|
104 | 104 | :type refcolumns: list of strings or Column instances |
|
105 | 105 | """ |
|
106 | 106 | |
|
107 | 107 | __migrate_visit_name__ = 'migrate_foreign_key_constraint' |
|
108 | 108 | |
|
109 | 109 | def __init__(self, columns, refcolumns, *args, **kwargs): |
|
110 | 110 | colnames, table = self._normalize_columns(columns) |
|
111 | 111 | table = kwargs.pop('table', table) |
|
112 | 112 | refcolnames, reftable = self._normalize_columns(refcolumns, |
|
113 | 113 | table_name=True) |
|
114 | 114 | super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args, |
|
115 | 115 | **kwargs) |
|
116 | 116 | if table is not None: |
|
117 | 117 | self._set_parent(table) |
|
118 | 118 | |
|
119 | 119 | @property |
|
120 | 120 | def referenced(self): |
|
121 | 121 | return [e.column for e in self.elements] |
|
122 | 122 | |
|
123 | 123 | @property |
|
124 | 124 | def reftable(self): |
|
125 | 125 | return self.referenced[0].table |
|
126 | 126 | |
|
127 | 127 | def autoname(self): |
|
128 | 128 | """Mimic the database's automatic constraint names""" |
|
129 | 129 | if hasattr(self.columns, 'keys'): |
|
130 | 130 | # SA <= 0.5 |
|
131 | 131 | firstcol = self.columns[self.columns.keys()[0]] |
|
132 | 132 | ret = "%(table)s_%(firstcolumn)s_fkey" % dict( |
|
133 | 133 | table=firstcol.table.name, |
|
134 | 134 | firstcolumn=firstcol.name,) |
|
135 | 135 | else: |
|
136 | 136 | # SA >= 0.6 |
|
137 | 137 | ret = "%(table)s_%(firstcolumn)s_fkey" % dict( |
|
138 | 138 | table=self.table.name, |
|
139 | 139 | firstcolumn=self.columns[0],) |
|
140 | 140 | return ret |
|
141 | 141 | |
|
142 | 142 | |
|
143 | 143 | class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): |
|
144 | 144 | """Construct CheckConstraint |
|
145 | 145 | |
|
146 | 146 | Migrate's additional parameters: |
|
147 | 147 | |
|
148 | 148 | :param sqltext: Plain SQL text to check condition |
|
149 | 149 | :param columns: If not name is applied, you must supply this kw\ |
|
150 | 150 | to autoname constraint |
|
151 | 151 | :param table: If columns are passed as strings, this kw is required |
|
152 | 152 | :type table: Table instance |
|
153 | 153 | :type columns: list of Columns instances |
|
154 | 154 | :type sqltext: string |
|
155 | 155 | """ |
|
156 | 156 | |
|
157 | 157 | __migrate_visit_name__ = 'migrate_check_constraint' |
|
158 | 158 | |
|
159 | 159 | def __init__(self, sqltext, *args, **kwargs): |
|
160 | 160 | cols = kwargs.pop('columns', []) |
|
161 | 161 | if not cols and not kwargs.get('name', False): |
|
162 | 162 | raise InvalidConstraintError('You must either set "name"' |
|
163 | 163 | 'parameter or "columns" to autogenarate it.') |
|
164 | 164 | colnames, table = self._normalize_columns(cols) |
|
165 | 165 | table = kwargs.pop('table', table) |
|
166 | 166 | schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs) |
|
167 | 167 | if table is not None: |
|
168 | 168 | if not SQLA_06: |
|
169 | 169 | self.table = table |
|
170 | 170 | self._set_parent(table) |
|
171 | 171 | self.colnames = colnames |
|
172 | 172 | |
|
173 | 173 | def autoname(self): |
|
174 | 174 | return "%(table)s_%(cols)s_check" % \ |
|
175 | 175 | dict(table=self.table.name, cols="_".join(self.colnames)) |
|
176 | 176 | |
|
177 | 177 | |
|
178 | 178 | class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint): |
|
179 | 179 | """Construct UniqueConstraint |
|
180 | 180 | |
|
181 | 181 | Migrate's additional parameters: |
|
182 | 182 | |
|
183 | 183 | :param cols: Columns in constraint. |
|
184 | 184 | :param table: If columns are passed as strings, this kw is required |
|
185 | 185 | :type table: Table instance |
|
186 | 186 | :type cols: strings or Column instances |
|
187 | 187 | |
|
188 | 188 | .. versionadded:: 0.6.0 |
|
189 | 189 | """ |
|
190 | 190 | |
|
191 | 191 | __migrate_visit_name__ = 'migrate_unique_constraint' |
|
192 | 192 | |
|
193 | 193 | def __init__(self, *cols, **kwargs): |
|
194 | 194 | self.colnames, table = self._normalize_columns(cols) |
|
195 | 195 | table = kwargs.pop('table', table) |
|
196 | 196 | super(UniqueConstraint, self).__init__(*self.colnames, **kwargs) |
|
197 | 197 | if table is not None: |
|
198 | 198 | self._set_parent(table) |
|
199 | 199 | |
|
200 | 200 | def autoname(self): |
|
201 | 201 | """Mimic the database's automatic constraint names""" |
|
202 | 202 | return "%s_%s_key" % (self.table.name, self.colnames[0]) |
@@ -1,80 +1,80 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Firebird database specific implementations of changeset classes. |
|
3 | 3 | """ |
|
4 | 4 | from sqlalchemy.databases import firebird as sa_base |
|
5 | 5 | |
|
6 | from migrate import exceptions | |
|
7 | from migrate.changeset import ansisql, SQLA_06 | |
|
6 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
7 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06 | |
|
8 | 8 | |
|
9 | 9 | |
|
10 | 10 | if SQLA_06: |
|
11 | 11 | FBSchemaGenerator = sa_base.FBDDLCompiler |
|
12 | 12 | else: |
|
13 | 13 | FBSchemaGenerator = sa_base.FBSchemaGenerator |
|
14 | 14 | |
|
15 | 15 | class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator): |
|
16 | 16 | """Firebird column generator implementation.""" |
|
17 | 17 | |
|
18 | 18 | |
|
19 | 19 | class FBColumnDropper(ansisql.ANSIColumnDropper): |
|
20 | 20 | """Firebird column dropper implementation.""" |
|
21 | 21 | |
|
22 | 22 | def visit_column(self, column): |
|
23 | 23 | """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax |
|
24 | 24 | |
|
25 | 25 | Drop primary key and unique constraints if dropped column is referencing it.""" |
|
26 | 26 | if column.primary_key: |
|
27 | 27 | if column.table.primary_key.columns.contains_column(column): |
|
28 | 28 | column.table.primary_key.drop() |
|
29 | 29 | # TODO: recreate primary key if it references more than this column |
|
30 | 30 | if column.unique or getattr(column, 'unique_name', None): |
|
31 | 31 | for cons in column.table.constraints: |
|
32 | 32 | if cons.contains_column(column): |
|
33 | 33 | cons.drop() |
|
34 | 34 | # TODO: recreate unique constraint if it refenrences more than this column |
|
35 | 35 | |
|
36 | 36 | table = self.start_alter_table(column) |
|
37 | 37 | self.append('DROP %s' % self.preparer.format_column(column)) |
|
38 | 38 | self.execute() |
|
39 | 39 | |
|
40 | 40 | |
|
41 | 41 | class FBSchemaChanger(ansisql.ANSISchemaChanger): |
|
42 | 42 | """Firebird schema changer implementation.""" |
|
43 | 43 | |
|
44 | 44 | def visit_table(self, table): |
|
45 | 45 | """Rename table not supported""" |
|
46 | 46 | raise exceptions.NotSupportedError( |
|
47 | 47 | "Firebird does not support renaming tables.") |
|
48 | 48 | |
|
49 | 49 | def _visit_column_name(self, table, column, delta): |
|
50 | 50 | self.start_alter_table(table) |
|
51 | 51 | col_name = self.preparer.quote(delta.current_name, table.quote) |
|
52 | 52 | new_name = self.preparer.format_column(delta.result_column) |
|
53 | 53 | self.append('ALTER COLUMN %s TO %s' % (col_name, new_name)) |
|
54 | 54 | |
|
55 | 55 | def _visit_column_nullable(self, table, column, delta): |
|
56 | 56 | """Changing NULL is not supported""" |
|
57 | 57 | # TODO: http://www.firebirdfaq.org/faq103/ |
|
58 | 58 | raise exceptions.NotSupportedError( |
|
59 | 59 | "Firebird does not support altering NULL bevahior.") |
|
60 | 60 | |
|
61 | 61 | |
|
62 | 62 | class FBConstraintGenerator(ansisql.ANSIConstraintGenerator): |
|
63 | 63 | """Firebird constraint generator implementation.""" |
|
64 | 64 | |
|
65 | 65 | |
|
66 | 66 | class FBConstraintDropper(ansisql.ANSIConstraintDropper): |
|
67 | 67 | """Firebird constaint dropper implementation.""" |
|
68 | 68 | |
|
69 | 69 | def cascade_constraint(self, constraint): |
|
70 | 70 | """Cascading constraints is not supported""" |
|
71 | 71 | raise exceptions.NotSupportedError( |
|
72 | 72 | "Firebird does not support cascading constraints") |
|
73 | 73 | |
|
74 | 74 | |
|
75 | 75 | class FBDialect(ansisql.ANSIDialect): |
|
76 | 76 | columngenerator = FBColumnGenerator |
|
77 | 77 | columndropper = FBColumnDropper |
|
78 | 78 | schemachanger = FBSchemaChanger |
|
79 | 79 | constraintgenerator = FBConstraintGenerator |
|
80 | 80 | constraintdropper = FBConstraintDropper |
@@ -1,94 +1,94 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | MySQL database specific implementations of changeset classes. |
|
3 | 3 | """ |
|
4 | 4 | |
|
5 | 5 | from sqlalchemy.databases import mysql as sa_base |
|
6 | 6 | from sqlalchemy import types as sqltypes |
|
7 | 7 | |
|
8 | from migrate import exceptions | |
|
9 | from migrate.changeset import ansisql, SQLA_06 | |
|
8 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
9 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06 | |
|
10 | 10 | |
|
11 | 11 | |
|
12 | 12 | if not SQLA_06: |
|
13 | 13 | MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator |
|
14 | 14 | else: |
|
15 | 15 | MySQLSchemaGenerator = sa_base.MySQLDDLCompiler |
|
16 | 16 | |
|
17 | 17 | class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator): |
|
18 | 18 | pass |
|
19 | 19 | |
|
20 | 20 | |
|
21 | 21 | class MySQLColumnDropper(ansisql.ANSIColumnDropper): |
|
22 | 22 | pass |
|
23 | 23 | |
|
24 | 24 | |
|
25 | 25 | class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger): |
|
26 | 26 | |
|
27 | 27 | def visit_column(self, delta): |
|
28 | 28 | table = delta.table |
|
29 | 29 | colspec = self.get_column_specification(delta.result_column) |
|
30 | 30 | if delta.result_column.autoincrement: |
|
31 | 31 | primary_keys = [c for c in table.primary_key.columns |
|
32 | 32 | if (c.autoincrement and |
|
33 | 33 | isinstance(c.type, sqltypes.Integer) and |
|
34 | 34 | not c.foreign_keys)] |
|
35 | 35 | |
|
36 | 36 | if primary_keys: |
|
37 | 37 | first = primary_keys.pop(0) |
|
38 | 38 | if first.name == delta.current_name: |
|
39 | 39 | colspec += " AUTO_INCREMENT" |
|
40 | 40 | old_col_name = self.preparer.quote(delta.current_name, table.quote) |
|
41 | 41 | |
|
42 | 42 | self.start_alter_table(table) |
|
43 | 43 | |
|
44 | 44 | self.append("CHANGE COLUMN %s " % old_col_name) |
|
45 | 45 | self.append(colspec) |
|
46 | 46 | self.execute() |
|
47 | 47 | |
|
48 | 48 | def visit_index(self, param): |
|
49 | 49 | # If MySQL can do this, I can't find how |
|
50 | 50 | raise exceptions.NotSupportedError("MySQL cannot rename indexes") |
|
51 | 51 | |
|
52 | 52 | |
|
53 | 53 | class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator): |
|
54 | 54 | pass |
|
55 | 55 | |
|
56 | 56 | if SQLA_06: |
|
57 | 57 | class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper): |
|
58 | 58 | def visit_migrate_check_constraint(self, *p, **k): |
|
59 | 59 | raise exceptions.NotSupportedError("MySQL does not support CHECK" |
|
60 | 60 | " constraints, use triggers instead.") |
|
61 | 61 | |
|
62 | 62 | else: |
|
63 | 63 | class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): |
|
64 | 64 | |
|
65 | 65 | def visit_migrate_primary_key_constraint(self, constraint): |
|
66 | 66 | self.start_alter_table(constraint) |
|
67 | 67 | self.append("DROP PRIMARY KEY") |
|
68 | 68 | self.execute() |
|
69 | 69 | |
|
70 | 70 | def visit_migrate_foreign_key_constraint(self, constraint): |
|
71 | 71 | self.start_alter_table(constraint) |
|
72 | 72 | self.append("DROP FOREIGN KEY ") |
|
73 | 73 | constraint.name = self.get_constraint_name(constraint) |
|
74 | 74 | self.append(self.preparer.format_constraint(constraint)) |
|
75 | 75 | self.execute() |
|
76 | 76 | |
|
77 | 77 | def visit_migrate_check_constraint(self, *p, **k): |
|
78 | 78 | raise exceptions.NotSupportedError("MySQL does not support CHECK" |
|
79 | 79 | " constraints, use triggers instead.") |
|
80 | 80 | |
|
81 | 81 | def visit_migrate_unique_constraint(self, constraint, *p, **k): |
|
82 | 82 | self.start_alter_table(constraint) |
|
83 | 83 | self.append('DROP INDEX ') |
|
84 | 84 | constraint.name = self.get_constraint_name(constraint) |
|
85 | 85 | self.append(self.preparer.format_constraint(constraint)) |
|
86 | 86 | self.execute() |
|
87 | 87 | |
|
88 | 88 | |
|
89 | 89 | class MySQLDialect(ansisql.ANSIDialect): |
|
90 | 90 | columngenerator = MySQLColumnGenerator |
|
91 | 91 | columndropper = MySQLColumnDropper |
|
92 | 92 | schemachanger = MySQLSchemaChanger |
|
93 | 93 | constraintgenerator = MySQLConstraintGenerator |
|
94 | 94 | constraintdropper = MySQLConstraintDropper |
@@ -1,111 +1,111 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Oracle database specific implementations of changeset classes. |
|
3 | 3 | """ |
|
4 | 4 | import sqlalchemy as sa |
|
5 | 5 | from sqlalchemy.databases import oracle as sa_base |
|
6 | 6 | |
|
7 | from migrate import exceptions | |
|
8 | from migrate.changeset import ansisql, SQLA_06 | |
|
7 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
8 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06 | |
|
9 | 9 | |
|
10 | 10 | |
|
11 | 11 | if not SQLA_06: |
|
12 | 12 | OracleSchemaGenerator = sa_base.OracleSchemaGenerator |
|
13 | 13 | else: |
|
14 | 14 | OracleSchemaGenerator = sa_base.OracleDDLCompiler |
|
15 | 15 | |
|
16 | 16 | |
|
17 | 17 | class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator): |
|
18 | 18 | pass |
|
19 | 19 | |
|
20 | 20 | |
|
21 | 21 | class OracleColumnDropper(ansisql.ANSIColumnDropper): |
|
22 | 22 | pass |
|
23 | 23 | |
|
24 | 24 | |
|
25 | 25 | class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger): |
|
26 | 26 | |
|
27 | 27 | def get_column_specification(self, column, **kwargs): |
|
28 | 28 | # Ignore the NOT NULL generated |
|
29 | 29 | override_nullable = kwargs.pop('override_nullable', None) |
|
30 | 30 | if override_nullable: |
|
31 | 31 | orig = column.nullable |
|
32 | 32 | column.nullable = True |
|
33 | 33 | ret = super(OracleSchemaChanger, self).get_column_specification( |
|
34 | 34 | column, **kwargs) |
|
35 | 35 | if override_nullable: |
|
36 | 36 | column.nullable = orig |
|
37 | 37 | return ret |
|
38 | 38 | |
|
39 | 39 | def visit_column(self, delta): |
|
40 | 40 | keys = delta.keys() |
|
41 | 41 | |
|
42 | 42 | if 'name' in keys: |
|
43 | 43 | self._run_subvisit(delta, |
|
44 | 44 | self._visit_column_name, |
|
45 | 45 | start_alter=False) |
|
46 | 46 | |
|
47 | 47 | if len(set(('type', 'nullable', 'server_default')).intersection(keys)): |
|
48 | 48 | self._run_subvisit(delta, |
|
49 | 49 | self._visit_column_change, |
|
50 | 50 | start_alter=False) |
|
51 | 51 | |
|
52 | 52 | def _visit_column_change(self, table, column, delta): |
|
53 | 53 | # Oracle cannot drop a default once created, but it can set it |
|
54 | 54 | # to null. We'll do that if default=None |
|
55 | 55 | # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234 |
|
56 | 56 | dropdefault_hack = (column.server_default is None \ |
|
57 | 57 | and 'server_default' in delta.keys()) |
|
58 | 58 | # Oracle apparently doesn't like it when we say "not null" if |
|
59 | 59 | # the column's already not null. Fudge it, so we don't need a |
|
60 | 60 | # new function |
|
61 | 61 | notnull_hack = ((not column.nullable) \ |
|
62 | 62 | and ('nullable' not in delta.keys())) |
|
63 | 63 | # We need to specify NULL if we're removing a NOT NULL |
|
64 | 64 | # constraint |
|
65 | 65 | null_hack = (column.nullable and ('nullable' in delta.keys())) |
|
66 | 66 | |
|
67 | 67 | if dropdefault_hack: |
|
68 | 68 | column.server_default = sa.PassiveDefault(sa.sql.null()) |
|
69 | 69 | if notnull_hack: |
|
70 | 70 | column.nullable = True |
|
71 | 71 | colspec = self.get_column_specification(column, |
|
72 | 72 | override_nullable=null_hack) |
|
73 | 73 | if null_hack: |
|
74 | 74 | colspec += ' NULL' |
|
75 | 75 | if notnull_hack: |
|
76 | 76 | column.nullable = False |
|
77 | 77 | if dropdefault_hack: |
|
78 | 78 | column.server_default = None |
|
79 | 79 | |
|
80 | 80 | self.start_alter_table(table) |
|
81 | 81 | self.append("MODIFY (") |
|
82 | 82 | self.append(colspec) |
|
83 | 83 | self.append(")") |
|
84 | 84 | |
|
85 | 85 | |
|
86 | 86 | class OracleConstraintCommon(object): |
|
87 | 87 | |
|
88 | 88 | def get_constraint_name(self, cons): |
|
89 | 89 | # Oracle constraints can't guess their name like other DBs |
|
90 | 90 | if not cons.name: |
|
91 | 91 | raise exceptions.NotSupportedError( |
|
92 | 92 | "Oracle constraint names must be explicitly stated") |
|
93 | 93 | return cons.name |
|
94 | 94 | |
|
95 | 95 | |
|
96 | 96 | class OracleConstraintGenerator(OracleConstraintCommon, |
|
97 | 97 | ansisql.ANSIConstraintGenerator): |
|
98 | 98 | pass |
|
99 | 99 | |
|
100 | 100 | |
|
101 | 101 | class OracleConstraintDropper(OracleConstraintCommon, |
|
102 | 102 | ansisql.ANSIConstraintDropper): |
|
103 | 103 | pass |
|
104 | 104 | |
|
105 | 105 | |
|
106 | 106 | class OracleDialect(ansisql.ANSIDialect): |
|
107 | 107 | columngenerator = OracleColumnGenerator |
|
108 | 108 | columndropper = OracleColumnDropper |
|
109 | 109 | schemachanger = OracleSchemaChanger |
|
110 | 110 | constraintgenerator = OracleConstraintGenerator |
|
111 | 111 | constraintdropper = OracleConstraintDropper |
@@ -1,46 +1,46 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | `PostgreSQL`_ database specific implementations of changeset classes. |
|
3 | 3 | |
|
4 | 4 | .. _`PostgreSQL`: http://www.postgresql.org/ |
|
5 | 5 | """ |
|
6 | from migrate.changeset import ansisql, SQLA_06 | |
|
6 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06 | |
|
7 | 7 | |
|
8 | 8 | if not SQLA_06: |
|
9 | 9 | from sqlalchemy.databases import postgres as sa_base |
|
10 | 10 | PGSchemaGenerator = sa_base.PGSchemaGenerator |
|
11 | 11 | else: |
|
12 | 12 | from sqlalchemy.databases import postgresql as sa_base |
|
13 | 13 | PGSchemaGenerator = sa_base.PGDDLCompiler |
|
14 | 14 | |
|
15 | 15 | |
|
16 | 16 | class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator): |
|
17 | 17 | """PostgreSQL column generator implementation.""" |
|
18 | 18 | pass |
|
19 | 19 | |
|
20 | 20 | |
|
21 | 21 | class PGColumnDropper(ansisql.ANSIColumnDropper): |
|
22 | 22 | """PostgreSQL column dropper implementation.""" |
|
23 | 23 | pass |
|
24 | 24 | |
|
25 | 25 | |
|
26 | 26 | class PGSchemaChanger(ansisql.ANSISchemaChanger): |
|
27 | 27 | """PostgreSQL schema changer implementation.""" |
|
28 | 28 | pass |
|
29 | 29 | |
|
30 | 30 | |
|
31 | 31 | class PGConstraintGenerator(ansisql.ANSIConstraintGenerator): |
|
32 | 32 | """PostgreSQL constraint generator implementation.""" |
|
33 | 33 | pass |
|
34 | 34 | |
|
35 | 35 | |
|
36 | 36 | class PGConstraintDropper(ansisql.ANSIConstraintDropper): |
|
37 | 37 | """PostgreSQL constaint dropper implementation.""" |
|
38 | 38 | pass |
|
39 | 39 | |
|
40 | 40 | |
|
41 | 41 | class PGDialect(ansisql.ANSIDialect): |
|
42 | 42 | columngenerator = PGColumnGenerator |
|
43 | 43 | columndropper = PGColumnDropper |
|
44 | 44 | schemachanger = PGSchemaChanger |
|
45 | 45 | constraintgenerator = PGConstraintGenerator |
|
46 | 46 | constraintdropper = PGConstraintDropper |
@@ -1,148 +1,148 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | `SQLite`_ database specific implementations of changeset classes. |
|
3 | 3 | |
|
4 | 4 | .. _`SQLite`: http://www.sqlite.org/ |
|
5 | 5 | """ |
|
6 | 6 | from UserDict import DictMixin |
|
7 | 7 | from copy import copy |
|
8 | 8 | |
|
9 | 9 | from sqlalchemy.databases import sqlite as sa_base |
|
10 | 10 | |
|
11 | from migrate import exceptions | |
|
12 | from migrate.changeset import ansisql, SQLA_06 | |
|
11 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
12 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql, SQLA_06 | |
|
13 | 13 | |
|
14 | 14 | |
|
15 | 15 | if not SQLA_06: |
|
16 | 16 | SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator |
|
17 | 17 | else: |
|
18 | 18 | SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler |
|
19 | 19 | |
|
20 | 20 | class SQLiteCommon(object): |
|
21 | 21 | |
|
22 | 22 | def _not_supported(self, op): |
|
23 | 23 | raise exceptions.NotSupportedError("SQLite does not support " |
|
24 | 24 | "%s; see http://www.sqlite.org/lang_altertable.html" % op) |
|
25 | 25 | |
|
26 | 26 | |
|
27 | 27 | class SQLiteHelper(SQLiteCommon): |
|
28 | 28 | |
|
29 | 29 | def recreate_table(self,table,column=None,delta=None): |
|
30 | 30 | table_name = self.preparer.format_table(table) |
|
31 | 31 | |
|
32 | 32 | # we remove all indexes so as not to have |
|
33 | 33 | # problems during copy and re-create |
|
34 | 34 | for index in table.indexes: |
|
35 | 35 | index.drop() |
|
36 | 36 | |
|
37 | 37 | self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name) |
|
38 | 38 | self.execute() |
|
39 | 39 | |
|
40 | 40 | insertion_string = self._modify_table(table, column, delta) |
|
41 | 41 | |
|
42 | 42 | table.create() |
|
43 | 43 | self.append(insertion_string % {'table_name': table_name}) |
|
44 | 44 | self.execute() |
|
45 | 45 | self.append('DROP TABLE migration_tmp') |
|
46 | 46 | self.execute() |
|
47 | 47 | |
|
48 | 48 | def visit_column(self, delta): |
|
49 | 49 | if isinstance(delta, DictMixin): |
|
50 | 50 | column = delta.result_column |
|
51 | 51 | table = self._to_table(delta.table) |
|
52 | 52 | else: |
|
53 | 53 | column = delta |
|
54 | 54 | table = self._to_table(column.table) |
|
55 | 55 | self.recreate_table(table,column,delta) |
|
56 | 56 | |
|
57 | 57 | class SQLiteColumnGenerator(SQLiteSchemaGenerator, |
|
58 | 58 | ansisql.ANSIColumnGenerator, |
|
59 | 59 | # at the end so we get the normal |
|
60 | 60 | # visit_column by default |
|
61 | 61 | SQLiteHelper, |
|
62 | 62 | SQLiteCommon |
|
63 | 63 | ): |
|
64 | 64 | """SQLite ColumnGenerator""" |
|
65 | 65 | |
|
66 | 66 | def _modify_table(self, table, column, delta): |
|
67 | 67 | columns = ' ,'.join(map( |
|
68 | 68 | self.preparer.format_column, |
|
69 | 69 | [c for c in table.columns if c.name!=column.name])) |
|
70 | 70 | return ('INSERT INTO %%(table_name)s (%(cols)s) ' |
|
71 | 71 | 'SELECT %(cols)s from migration_tmp')%{'cols':columns} |
|
72 | 72 | |
|
73 | 73 | def visit_column(self,column): |
|
74 | 74 | if column.foreign_keys: |
|
75 | 75 | SQLiteHelper.visit_column(self,column) |
|
76 | 76 | else: |
|
77 | 77 | super(SQLiteColumnGenerator,self).visit_column(column) |
|
78 | 78 | |
|
79 | 79 | class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper): |
|
80 | 80 | """SQLite ColumnDropper""" |
|
81 | 81 | |
|
82 | 82 | def _modify_table(self, table, column, delta): |
|
83 | 83 | columns = ' ,'.join(map(self.preparer.format_column, table.columns)) |
|
84 | 84 | return 'INSERT INTO %(table_name)s SELECT ' + columns + \ |
|
85 | 85 | ' from migration_tmp' |
|
86 | 86 | |
|
87 | 87 | |
|
88 | 88 | class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger): |
|
89 | 89 | """SQLite SchemaChanger""" |
|
90 | 90 | |
|
91 | 91 | def _modify_table(self, table, column, delta): |
|
92 | 92 | return 'INSERT INTO %(table_name)s SELECT * from migration_tmp' |
|
93 | 93 | |
|
94 | 94 | def visit_index(self, index): |
|
95 | 95 | """Does not support ALTER INDEX""" |
|
96 | 96 | self._not_supported('ALTER INDEX') |
|
97 | 97 | |
|
98 | 98 | |
|
99 | 99 | class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon): |
|
100 | 100 | |
|
101 | 101 | def visit_migrate_primary_key_constraint(self, constraint): |
|
102 | 102 | tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )" |
|
103 | 103 | cols = ', '.join(map(self.preparer.format_column, constraint.columns)) |
|
104 | 104 | tname = self.preparer.format_table(constraint.table) |
|
105 | 105 | name = self.get_constraint_name(constraint) |
|
106 | 106 | msg = tmpl % (name, tname, cols) |
|
107 | 107 | self.append(msg) |
|
108 | 108 | self.execute() |
|
109 | 109 | |
|
110 | 110 | def _modify_table(self, table, column, delta): |
|
111 | 111 | return 'INSERT INTO %(table_name)s SELECT * from migration_tmp' |
|
112 | 112 | |
|
113 | 113 | def visit_migrate_foreign_key_constraint(self, *p, **k): |
|
114 | 114 | self.recreate_table(p[0].table) |
|
115 | 115 | |
|
116 | 116 | def visit_migrate_unique_constraint(self, *p, **k): |
|
117 | 117 | self.recreate_table(p[0].table) |
|
118 | 118 | |
|
119 | 119 | |
|
120 | 120 | class SQLiteConstraintDropper(ansisql.ANSIColumnDropper, |
|
121 | 121 | SQLiteCommon, |
|
122 | 122 | ansisql.ANSIConstraintCommon): |
|
123 | 123 | |
|
124 | 124 | def visit_migrate_primary_key_constraint(self, constraint): |
|
125 | 125 | tmpl = "DROP INDEX %s " |
|
126 | 126 | name = self.get_constraint_name(constraint) |
|
127 | 127 | msg = tmpl % (name) |
|
128 | 128 | self.append(msg) |
|
129 | 129 | self.execute() |
|
130 | 130 | |
|
131 | 131 | def visit_migrate_foreign_key_constraint(self, *p, **k): |
|
132 | 132 | self._not_supported('ALTER TABLE DROP CONSTRAINT') |
|
133 | 133 | |
|
134 | 134 | def visit_migrate_check_constraint(self, *p, **k): |
|
135 | 135 | self._not_supported('ALTER TABLE DROP CONSTRAINT') |
|
136 | 136 | |
|
137 | 137 | def visit_migrate_unique_constraint(self, *p, **k): |
|
138 | 138 | self._not_supported('ALTER TABLE DROP CONSTRAINT') |
|
139 | 139 | |
|
140 | 140 | |
|
141 | 141 | # TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index |
|
142 | 142 | |
|
143 | 143 | class SQLiteDialect(ansisql.ANSIDialect): |
|
144 | 144 | columngenerator = SQLiteColumnGenerator |
|
145 | 145 | columndropper = SQLiteColumnDropper |
|
146 | 146 | schemachanger = SQLiteSchemaChanger |
|
147 | 147 | constraintgenerator = SQLiteConstraintGenerator |
|
148 | 148 | constraintdropper = SQLiteConstraintDropper |
@@ -1,78 +1,78 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Module for visitor class mapping. |
|
3 | 3 | """ |
|
4 | 4 | import sqlalchemy as sa |
|
5 | 5 | |
|
6 | from migrate.changeset import ansisql | |
|
7 | from migrate.changeset.databases import (sqlite, | |
|
6 | from rhodecode.lib.dbmigrate.migrate.changeset import ansisql | |
|
7 | from rhodecode.lib.dbmigrate.migrate.changeset.databases import (sqlite, | |
|
8 | 8 | postgres, |
|
9 | 9 | mysql, |
|
10 | 10 | oracle, |
|
11 | 11 | firebird) |
|
12 | 12 | |
|
13 | 13 | |
|
14 | 14 | # Map SA dialects to the corresponding Migrate extensions |
|
15 | 15 | DIALECTS = { |
|
16 | 16 | "default": ansisql.ANSIDialect, |
|
17 | 17 | "sqlite": sqlite.SQLiteDialect, |
|
18 | 18 | "postgres": postgres.PGDialect, |
|
19 | 19 | "postgresql": postgres.PGDialect, |
|
20 | 20 | "mysql": mysql.MySQLDialect, |
|
21 | 21 | "oracle": oracle.OracleDialect, |
|
22 | 22 | "firebird": firebird.FBDialect, |
|
23 | 23 | } |
|
24 | 24 | |
|
25 | 25 | |
|
26 | 26 | def get_engine_visitor(engine, name): |
|
27 | 27 | """ |
|
28 | 28 | Get the visitor implementation for the given database engine. |
|
29 | 29 | |
|
30 | 30 | :param engine: SQLAlchemy Engine |
|
31 | 31 | :param name: Name of the visitor |
|
32 | 32 | :type name: string |
|
33 | 33 | :type engine: Engine |
|
34 | 34 | :returns: visitor |
|
35 | 35 | """ |
|
36 | 36 | # TODO: link to supported visitors |
|
37 | 37 | return get_dialect_visitor(engine.dialect, name) |
|
38 | 38 | |
|
39 | 39 | |
|
40 | 40 | def get_dialect_visitor(sa_dialect, name): |
|
41 | 41 | """ |
|
42 | 42 | Get the visitor implementation for the given dialect. |
|
43 | 43 | |
|
44 | 44 | Finds the visitor implementation based on the dialect class and |
|
45 | 45 | returns and instance initialized with the given name. |
|
46 | 46 | |
|
47 | 47 | Binds dialect specific preparer to visitor. |
|
48 | 48 | """ |
|
49 | 49 | |
|
50 | 50 | # map sa dialect to migrate dialect and return visitor |
|
51 | 51 | sa_dialect_name = getattr(sa_dialect, 'name', 'default') |
|
52 | 52 | migrate_dialect_cls = DIALECTS[sa_dialect_name] |
|
53 | 53 | visitor = getattr(migrate_dialect_cls, name) |
|
54 | 54 | |
|
55 | 55 | # bind preparer |
|
56 | 56 | visitor.preparer = sa_dialect.preparer(sa_dialect) |
|
57 | 57 | |
|
58 | 58 | return visitor |
|
59 | 59 | |
|
60 | 60 | def run_single_visitor(engine, visitorcallable, element, |
|
61 | 61 | connection=None, **kwargs): |
|
62 | 62 | """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor` |
|
63 | 63 | with support for migrate visitors. |
|
64 | 64 | """ |
|
65 | 65 | if connection is None: |
|
66 | 66 | conn = engine.contextual_connect(close_with_result=False) |
|
67 | 67 | else: |
|
68 | 68 | conn = connection |
|
69 | 69 | visitor = visitorcallable(engine.dialect, conn) |
|
70 | 70 | try: |
|
71 | 71 | if hasattr(element, '__migrate_visit_name__'): |
|
72 | 72 | fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__) |
|
73 | 73 | else: |
|
74 | 74 | fn = getattr(visitor, 'visit_' + element.__visit_name__) |
|
75 | 75 | fn(element, **kwargs) |
|
76 | 76 | finally: |
|
77 | 77 | if connection is None: |
|
78 | 78 | conn.close() |
@@ -1,669 +1,669 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Schema module providing common schema operations. |
|
3 | 3 | """ |
|
4 | 4 | import warnings |
|
5 | 5 | |
|
6 | 6 | from UserDict import DictMixin |
|
7 | 7 | |
|
8 | 8 | import sqlalchemy |
|
9 | 9 | |
|
10 | 10 | from sqlalchemy.schema import ForeignKeyConstraint |
|
11 | 11 | from sqlalchemy.schema import UniqueConstraint |
|
12 | 12 | |
|
13 | from migrate.exceptions import * | |
|
14 | from migrate.changeset import SQLA_06 | |
|
15 | from migrate.changeset.databases.visitor import (get_engine_visitor, | |
|
13 | from rhodecode.lib.dbmigrate.migrate.exceptions import * | |
|
14 | from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06 | |
|
15 | from rhodecode.lib.dbmigrate.migrate.changeset.databases.visitor import (get_engine_visitor, | |
|
16 | 16 | run_single_visitor) |
|
17 | 17 | |
|
18 | 18 | |
|
19 | 19 | __all__ = [ |
|
20 | 20 | 'create_column', |
|
21 | 21 | 'drop_column', |
|
22 | 22 | 'alter_column', |
|
23 | 23 | 'rename_table', |
|
24 | 24 | 'rename_index', |
|
25 | 25 | 'ChangesetTable', |
|
26 | 26 | 'ChangesetColumn', |
|
27 | 27 | 'ChangesetIndex', |
|
28 | 28 | 'ChangesetDefaultClause', |
|
29 | 29 | 'ColumnDelta', |
|
30 | 30 | ] |
|
31 | 31 | |
|
32 | 32 | DEFAULT_ALTER_METADATA = True |
|
33 | 33 | |
|
34 | 34 | |
|
35 | 35 | def create_column(column, table=None, *p, **kw): |
|
36 | 36 | """Create a column, given the table. |
|
37 | 37 | |
|
38 | 38 | API to :meth:`ChangesetColumn.create`. |
|
39 | 39 | """ |
|
40 | 40 | if table is not None: |
|
41 | 41 | return table.create_column(column, *p, **kw) |
|
42 | 42 | return column.create(*p, **kw) |
|
43 | 43 | |
|
44 | 44 | |
|
45 | 45 | def drop_column(column, table=None, *p, **kw): |
|
46 | 46 | """Drop a column, given the table. |
|
47 | 47 | |
|
48 | 48 | API to :meth:`ChangesetColumn.drop`. |
|
49 | 49 | """ |
|
50 | 50 | if table is not None: |
|
51 | 51 | return table.drop_column(column, *p, **kw) |
|
52 | 52 | return column.drop(*p, **kw) |
|
53 | 53 | |
|
54 | 54 | |
|
55 | 55 | def rename_table(table, name, engine=None, **kw): |
|
56 | 56 | """Rename a table. |
|
57 | 57 | |
|
58 | 58 | If Table instance is given, engine is not used. |
|
59 | 59 | |
|
60 | 60 | API to :meth:`ChangesetTable.rename`. |
|
61 | 61 | |
|
62 | 62 | :param table: Table to be renamed. |
|
63 | 63 | :param name: New name for Table. |
|
64 | 64 | :param engine: Engine instance. |
|
65 | 65 | :type table: string or Table instance |
|
66 | 66 | :type name: string |
|
67 | 67 | :type engine: obj |
|
68 | 68 | """ |
|
69 | 69 | table = _to_table(table, engine) |
|
70 | 70 | table.rename(name, **kw) |
|
71 | 71 | |
|
72 | 72 | |
|
73 | 73 | def rename_index(index, name, table=None, engine=None, **kw): |
|
74 | 74 | """Rename an index. |
|
75 | 75 | |
|
76 | 76 | If Index instance is given, |
|
77 | 77 | table and engine are not used. |
|
78 | 78 | |
|
79 | 79 | API to :meth:`ChangesetIndex.rename`. |
|
80 | 80 | |
|
81 | 81 | :param index: Index to be renamed. |
|
82 | 82 | :param name: New name for index. |
|
83 | 83 | :param table: Table to which Index is reffered. |
|
84 | 84 | :param engine: Engine instance. |
|
85 | 85 | :type index: string or Index instance |
|
86 | 86 | :type name: string |
|
87 | 87 | :type table: string or Table instance |
|
88 | 88 | :type engine: obj |
|
89 | 89 | """ |
|
90 | 90 | index = _to_index(index, table, engine) |
|
91 | 91 | index.rename(name, **kw) |
|
92 | 92 | |
|
93 | 93 | |
|
94 | 94 | def alter_column(*p, **k): |
|
95 | 95 | """Alter a column. |
|
96 | 96 | |
|
97 | 97 | This is a helper function that creates a :class:`ColumnDelta` and |
|
98 | 98 | runs it. |
|
99 | 99 | |
|
100 | 100 | :argument column: |
|
101 | 101 | The name of the column to be altered or a |
|
102 | 102 | :class:`ChangesetColumn` column representing it. |
|
103 | 103 | |
|
104 | 104 | :param table: |
|
105 | 105 | A :class:`~sqlalchemy.schema.Table` or table name to |
|
106 | 106 | for the table where the column will be changed. |
|
107 | 107 | |
|
108 | 108 | :param engine: |
|
109 | 109 | The :class:`~sqlalchemy.engine.base.Engine` to use for table |
|
110 | 110 | reflection and schema alterations. |
|
111 | 111 | |
|
112 | 112 | :param alter_metadata: |
|
113 | 113 | If `True`, which is the default, the |
|
114 | 114 | :class:`~sqlalchemy.schema.Column` will also modified. |
|
115 | 115 | If `False`, the :class:`~sqlalchemy.schema.Column` will be left |
|
116 | 116 | as it was. |
|
117 | 117 | |
|
118 | 118 | :returns: A :class:`ColumnDelta` instance representing the change. |
|
119 | 119 | |
|
120 | 120 | |
|
121 | 121 | """ |
|
122 | 122 | |
|
123 | 123 | k.setdefault('alter_metadata', DEFAULT_ALTER_METADATA) |
|
124 | 124 | |
|
125 | 125 | if 'table' not in k and isinstance(p[0], sqlalchemy.Column): |
|
126 | 126 | k['table'] = p[0].table |
|
127 | 127 | if 'engine' not in k: |
|
128 | 128 | k['engine'] = k['table'].bind |
|
129 | 129 | |
|
130 | 130 | # deprecation |
|
131 | 131 | if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column): |
|
132 | 132 | warnings.warn( |
|
133 | 133 | "Passing a Column object to alter_column is deprecated." |
|
134 | 134 | " Just pass in keyword parameters instead.", |
|
135 | 135 | MigrateDeprecationWarning |
|
136 | 136 | ) |
|
137 | 137 | engine = k['engine'] |
|
138 | 138 | delta = ColumnDelta(*p, **k) |
|
139 | 139 | |
|
140 | 140 | visitorcallable = get_engine_visitor(engine, 'schemachanger') |
|
141 | 141 | engine._run_visitor(visitorcallable, delta) |
|
142 | 142 | |
|
143 | 143 | return delta |
|
144 | 144 | |
|
145 | 145 | |
|
146 | 146 | def _to_table(table, engine=None): |
|
147 | 147 | """Return if instance of Table, else construct new with metadata""" |
|
148 | 148 | if isinstance(table, sqlalchemy.Table): |
|
149 | 149 | return table |
|
150 | 150 | |
|
151 | 151 | # Given: table name, maybe an engine |
|
152 | 152 | meta = sqlalchemy.MetaData() |
|
153 | 153 | if engine is not None: |
|
154 | 154 | meta.bind = engine |
|
155 | 155 | return sqlalchemy.Table(table, meta) |
|
156 | 156 | |
|
157 | 157 | |
|
158 | 158 | def _to_index(index, table=None, engine=None): |
|
159 | 159 | """Return if instance of Index, else construct new with metadata""" |
|
160 | 160 | if isinstance(index, sqlalchemy.Index): |
|
161 | 161 | return index |
|
162 | 162 | |
|
163 | 163 | # Given: index name; table name required |
|
164 | 164 | table = _to_table(table, engine) |
|
165 | 165 | ret = sqlalchemy.Index(index) |
|
166 | 166 | ret.table = table |
|
167 | 167 | return ret |
|
168 | 168 | |
|
169 | 169 | |
|
170 | 170 | class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem): |
|
171 | 171 | """Extracts the differences between two columns/column-parameters |
|
172 | 172 | |
|
173 | 173 | May receive parameters arranged in several different ways: |
|
174 | 174 | |
|
175 | 175 | * **current_column, new_column, \*p, \*\*kw** |
|
176 | 176 | Additional parameters can be specified to override column |
|
177 | 177 | differences. |
|
178 | 178 | |
|
179 | 179 | * **current_column, \*p, \*\*kw** |
|
180 | 180 | Additional parameters alter current_column. Table name is extracted |
|
181 | 181 | from current_column object. |
|
182 | 182 | Name is changed to current_column.name from current_name, |
|
183 | 183 | if current_name is specified. |
|
184 | 184 | |
|
185 | 185 | * **current_col_name, \*p, \*\*kw** |
|
186 | 186 | Table kw must specified. |
|
187 | 187 | |
|
188 | 188 | :param table: Table at which current Column should be bound to.\ |
|
189 | 189 | If table name is given, reflection will be used. |
|
190 | 190 | :type table: string or Table instance |
|
191 | 191 | :param alter_metadata: If True, it will apply changes to metadata. |
|
192 | 192 | :type alter_metadata: bool |
|
193 | 193 | :param metadata: If `alter_metadata` is true, \ |
|
194 | 194 | metadata is used to reflect table names into |
|
195 | 195 | :type metadata: :class:`MetaData` instance |
|
196 | 196 | :param engine: When reflecting tables, either engine or metadata must \ |
|
197 | 197 | be specified to acquire engine object. |
|
198 | 198 | :type engine: :class:`Engine` instance |
|
199 | 199 | :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \ |
|
200 | 200 | `result_column` through :func:`dict` alike object. |
|
201 | 201 | |
|
202 | 202 | * :class:`ColumnDelta`.result_column is altered column with new attributes |
|
203 | 203 | |
|
204 | 204 | * :class:`ColumnDelta`.current_name is current name of column in db |
|
205 | 205 | |
|
206 | 206 | |
|
207 | 207 | """ |
|
208 | 208 | |
|
209 | 209 | # Column attributes that can be altered |
|
210 | 210 | diff_keys = ('name', 'type', 'primary_key', 'nullable', |
|
211 | 211 | 'server_onupdate', 'server_default', 'autoincrement') |
|
212 | 212 | diffs = dict() |
|
213 | 213 | __visit_name__ = 'column' |
|
214 | 214 | |
|
215 | 215 | def __init__(self, *p, **kw): |
|
216 | 216 | self.alter_metadata = kw.pop("alter_metadata", False) |
|
217 | 217 | self.meta = kw.pop("metadata", None) |
|
218 | 218 | self.engine = kw.pop("engine", None) |
|
219 | 219 | |
|
220 | 220 | # Things are initialized differently depending on how many column |
|
221 | 221 | # parameters are given. Figure out how many and call the appropriate |
|
222 | 222 | # method. |
|
223 | 223 | if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column): |
|
224 | 224 | # At least one column specified |
|
225 | 225 | if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column): |
|
226 | 226 | # Two columns specified |
|
227 | 227 | diffs = self.compare_2_columns(*p, **kw) |
|
228 | 228 | else: |
|
229 | 229 | # Exactly one column specified |
|
230 | 230 | diffs = self.compare_1_column(*p, **kw) |
|
231 | 231 | else: |
|
232 | 232 | # Zero columns specified |
|
233 | 233 | if not len(p) or not isinstance(p[0], basestring): |
|
234 | 234 | raise ValueError("First argument must be column name") |
|
235 | 235 | diffs = self.compare_parameters(*p, **kw) |
|
236 | 236 | |
|
237 | 237 | self.apply_diffs(diffs) |
|
238 | 238 | |
|
239 | 239 | def __repr__(self): |
|
240 | 240 | return '<ColumnDelta altermetadata=%r, %s>' % (self.alter_metadata, |
|
241 | 241 | super(ColumnDelta, self).__repr__()) |
|
242 | 242 | |
|
243 | 243 | def __getitem__(self, key): |
|
244 | 244 | if key not in self.keys(): |
|
245 | 245 | raise KeyError("No such diff key, available: %s" % self.diffs) |
|
246 | 246 | return getattr(self.result_column, key) |
|
247 | 247 | |
|
248 | 248 | def __setitem__(self, key, value): |
|
249 | 249 | if key not in self.keys(): |
|
250 | 250 | raise KeyError("No such diff key, available: %s" % self.diffs) |
|
251 | 251 | setattr(self.result_column, key, value) |
|
252 | 252 | |
|
253 | 253 | def __delitem__(self, key): |
|
254 | 254 | raise NotImplementedError |
|
255 | 255 | |
|
256 | 256 | def keys(self): |
|
257 | 257 | return self.diffs.keys() |
|
258 | 258 | |
|
259 | 259 | def compare_parameters(self, current_name, *p, **k): |
|
260 | 260 | """Compares Column objects with reflection""" |
|
261 | 261 | self.table = k.pop('table') |
|
262 | 262 | self.result_column = self._table.c.get(current_name) |
|
263 | 263 | if len(p): |
|
264 | 264 | k = self._extract_parameters(p, k, self.result_column) |
|
265 | 265 | return k |
|
266 | 266 | |
|
267 | 267 | def compare_1_column(self, col, *p, **k): |
|
268 | 268 | """Compares one Column object""" |
|
269 | 269 | self.table = k.pop('table', None) |
|
270 | 270 | if self.table is None: |
|
271 | 271 | self.table = col.table |
|
272 | 272 | self.result_column = col |
|
273 | 273 | if len(p): |
|
274 | 274 | k = self._extract_parameters(p, k, self.result_column) |
|
275 | 275 | return k |
|
276 | 276 | |
|
277 | 277 | def compare_2_columns(self, old_col, new_col, *p, **k): |
|
278 | 278 | """Compares two Column objects""" |
|
279 | 279 | self.process_column(new_col) |
|
280 | 280 | self.table = k.pop('table', None) |
|
281 | 281 | # we cannot use bool() on table in SA06 |
|
282 | 282 | if self.table is None: |
|
283 | 283 | self.table = old_col.table |
|
284 | 284 | if self.table is None: |
|
285 | 285 | new_col.table |
|
286 | 286 | self.result_column = old_col |
|
287 | 287 | |
|
288 | 288 | # set differences |
|
289 | 289 | # leave out some stuff for later comp |
|
290 | 290 | for key in (set(self.diff_keys) - set(('type',))): |
|
291 | 291 | val = getattr(new_col, key, None) |
|
292 | 292 | if getattr(self.result_column, key, None) != val: |
|
293 | 293 | k.setdefault(key, val) |
|
294 | 294 | |
|
295 | 295 | # inspect types |
|
296 | 296 | if not self.are_column_types_eq(self.result_column.type, new_col.type): |
|
297 | 297 | k.setdefault('type', new_col.type) |
|
298 | 298 | |
|
299 | 299 | if len(p): |
|
300 | 300 | k = self._extract_parameters(p, k, self.result_column) |
|
301 | 301 | return k |
|
302 | 302 | |
|
303 | 303 | def apply_diffs(self, diffs): |
|
304 | 304 | """Populate dict and column object with new values""" |
|
305 | 305 | self.diffs = diffs |
|
306 | 306 | for key in self.diff_keys: |
|
307 | 307 | if key in diffs: |
|
308 | 308 | setattr(self.result_column, key, diffs[key]) |
|
309 | 309 | |
|
310 | 310 | self.process_column(self.result_column) |
|
311 | 311 | |
|
312 | 312 | # create an instance of class type if not yet |
|
313 | 313 | if 'type' in diffs and callable(self.result_column.type): |
|
314 | 314 | self.result_column.type = self.result_column.type() |
|
315 | 315 | |
|
316 | 316 | # add column to the table |
|
317 | 317 | if self.table is not None and self.alter_metadata: |
|
318 | 318 | self.result_column.add_to_table(self.table) |
|
319 | 319 | |
|
320 | 320 | def are_column_types_eq(self, old_type, new_type): |
|
321 | 321 | """Compares two types to be equal""" |
|
322 | 322 | ret = old_type.__class__ == new_type.__class__ |
|
323 | 323 | |
|
324 | 324 | # String length is a special case |
|
325 | 325 | if ret and isinstance(new_type, sqlalchemy.types.String): |
|
326 | 326 | ret = (getattr(old_type, 'length', None) == \ |
|
327 | 327 | getattr(new_type, 'length', None)) |
|
328 | 328 | return ret |
|
329 | 329 | |
|
330 | 330 | def _extract_parameters(self, p, k, column): |
|
331 | 331 | """Extracts data from p and modifies diffs""" |
|
332 | 332 | p = list(p) |
|
333 | 333 | while len(p): |
|
334 | 334 | if isinstance(p[0], basestring): |
|
335 | 335 | k.setdefault('name', p.pop(0)) |
|
336 | 336 | elif isinstance(p[0], sqlalchemy.types.AbstractType): |
|
337 | 337 | k.setdefault('type', p.pop(0)) |
|
338 | 338 | elif callable(p[0]): |
|
339 | 339 | p[0] = p[0]() |
|
340 | 340 | else: |
|
341 | 341 | break |
|
342 | 342 | |
|
343 | 343 | if len(p): |
|
344 | 344 | new_col = column.copy_fixed() |
|
345 | 345 | new_col._init_items(*p) |
|
346 | 346 | k = self.compare_2_columns(column, new_col, **k) |
|
347 | 347 | return k |
|
348 | 348 | |
|
349 | 349 | def process_column(self, column): |
|
350 | 350 | """Processes default values for column""" |
|
351 | 351 | # XXX: this is a snippet from SA processing of positional parameters |
|
352 | 352 | if not SQLA_06 and column.args: |
|
353 | 353 | toinit = list(column.args) |
|
354 | 354 | else: |
|
355 | 355 | toinit = list() |
|
356 | 356 | |
|
357 | 357 | if column.server_default is not None: |
|
358 | 358 | if isinstance(column.server_default, sqlalchemy.FetchedValue): |
|
359 | 359 | toinit.append(column.server_default) |
|
360 | 360 | else: |
|
361 | 361 | toinit.append(sqlalchemy.DefaultClause(column.server_default)) |
|
362 | 362 | if column.server_onupdate is not None: |
|
363 | 363 | if isinstance(column.server_onupdate, FetchedValue): |
|
364 | 364 | toinit.append(column.server_default) |
|
365 | 365 | else: |
|
366 | 366 | toinit.append(sqlalchemy.DefaultClause(column.server_onupdate, |
|
367 | 367 | for_update=True)) |
|
368 | 368 | if toinit: |
|
369 | 369 | column._init_items(*toinit) |
|
370 | 370 | |
|
371 | 371 | if not SQLA_06: |
|
372 | 372 | column.args = [] |
|
373 | 373 | |
|
374 | 374 | def _get_table(self): |
|
375 | 375 | return getattr(self, '_table', None) |
|
376 | 376 | |
|
377 | 377 | def _set_table(self, table): |
|
378 | 378 | if isinstance(table, basestring): |
|
379 | 379 | if self.alter_metadata: |
|
380 | 380 | if not self.meta: |
|
381 | 381 | raise ValueError("metadata must be specified for table" |
|
382 | 382 | " reflection when using alter_metadata") |
|
383 | 383 | meta = self.meta |
|
384 | 384 | if self.engine: |
|
385 | 385 | meta.bind = self.engine |
|
386 | 386 | else: |
|
387 | 387 | if not self.engine and not self.meta: |
|
388 | 388 | raise ValueError("engine or metadata must be specified" |
|
389 | 389 | " to reflect tables") |
|
390 | 390 | if not self.engine: |
|
391 | 391 | self.engine = self.meta.bind |
|
392 | 392 | meta = sqlalchemy.MetaData(bind=self.engine) |
|
393 | 393 | self._table = sqlalchemy.Table(table, meta, autoload=True) |
|
394 | 394 | elif isinstance(table, sqlalchemy.Table): |
|
395 | 395 | self._table = table |
|
396 | 396 | if not self.alter_metadata: |
|
397 | 397 | self._table.meta = sqlalchemy.MetaData(bind=self._table.bind) |
|
398 | 398 | |
|
399 | 399 | def _get_result_column(self): |
|
400 | 400 | return getattr(self, '_result_column', None) |
|
401 | 401 | |
|
402 | 402 | def _set_result_column(self, column): |
|
403 | 403 | """Set Column to Table based on alter_metadata evaluation.""" |
|
404 | 404 | self.process_column(column) |
|
405 | 405 | if not hasattr(self, 'current_name'): |
|
406 | 406 | self.current_name = column.name |
|
407 | 407 | if self.alter_metadata: |
|
408 | 408 | self._result_column = column |
|
409 | 409 | else: |
|
410 | 410 | self._result_column = column.copy_fixed() |
|
411 | 411 | |
|
412 | 412 | table = property(_get_table, _set_table) |
|
413 | 413 | result_column = property(_get_result_column, _set_result_column) |
|
414 | 414 | |
|
415 | 415 | |
|
416 | 416 | class ChangesetTable(object): |
|
417 | 417 | """Changeset extensions to SQLAlchemy tables.""" |
|
418 | 418 | |
|
419 | 419 | def create_column(self, column, *p, **kw): |
|
420 | 420 | """Creates a column. |
|
421 | 421 | |
|
422 | 422 | The column parameter may be a column definition or the name of |
|
423 | 423 | a column in this table. |
|
424 | 424 | |
|
425 | 425 | API to :meth:`ChangesetColumn.create` |
|
426 | 426 | |
|
427 | 427 | :param column: Column to be created |
|
428 | 428 | :type column: Column instance or string |
|
429 | 429 | """ |
|
430 | 430 | if not isinstance(column, sqlalchemy.Column): |
|
431 | 431 | # It's a column name |
|
432 | 432 | column = getattr(self.c, str(column)) |
|
433 | 433 | column.create(table=self, *p, **kw) |
|
434 | 434 | |
|
435 | 435 | def drop_column(self, column, *p, **kw): |
|
436 | 436 | """Drop a column, given its name or definition. |
|
437 | 437 | |
|
438 | 438 | API to :meth:`ChangesetColumn.drop` |
|
439 | 439 | |
|
440 | 440 | :param column: Column to be droped |
|
441 | 441 | :type column: Column instance or string |
|
442 | 442 | """ |
|
443 | 443 | if not isinstance(column, sqlalchemy.Column): |
|
444 | 444 | # It's a column name |
|
445 | 445 | try: |
|
446 | 446 | column = getattr(self.c, str(column)) |
|
447 | 447 | except AttributeError: |
|
448 | 448 | # That column isn't part of the table. We don't need |
|
449 | 449 | # its entire definition to drop the column, just its |
|
450 | 450 | # name, so create a dummy column with the same name. |
|
451 | 451 | column = sqlalchemy.Column(str(column), sqlalchemy.Integer()) |
|
452 | 452 | column.drop(table=self, *p, **kw) |
|
453 | 453 | |
|
454 | 454 | def rename(self, name, connection=None, **kwargs): |
|
455 | 455 | """Rename this table. |
|
456 | 456 | |
|
457 | 457 | :param name: New name of the table. |
|
458 | 458 | :type name: string |
|
459 | 459 | :param alter_metadata: If True, table will be removed from metadata |
|
460 | 460 | :type alter_metadata: bool |
|
461 | 461 | :param connection: reuse connection istead of creating new one. |
|
462 | 462 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
463 | 463 | """ |
|
464 | 464 | self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) |
|
465 | 465 | engine = self.bind |
|
466 | 466 | self.new_name = name |
|
467 | 467 | visitorcallable = get_engine_visitor(engine, 'schemachanger') |
|
468 | 468 | run_single_visitor(engine, visitorcallable, self, connection, **kwargs) |
|
469 | 469 | |
|
470 | 470 | # Fix metadata registration |
|
471 | 471 | if self.alter_metadata: |
|
472 | 472 | self.name = name |
|
473 | 473 | self.deregister() |
|
474 | 474 | self._set_parent(self.metadata) |
|
475 | 475 | |
|
476 | 476 | def _meta_key(self): |
|
477 | 477 | return sqlalchemy.schema._get_table_key(self.name, self.schema) |
|
478 | 478 | |
|
479 | 479 | def deregister(self): |
|
480 | 480 | """Remove this table from its metadata""" |
|
481 | 481 | key = self._meta_key() |
|
482 | 482 | meta = self.metadata |
|
483 | 483 | if key in meta.tables: |
|
484 | 484 | del meta.tables[key] |
|
485 | 485 | |
|
486 | 486 | |
|
487 | 487 | class ChangesetColumn(object): |
|
488 | 488 | """Changeset extensions to SQLAlchemy columns.""" |
|
489 | 489 | |
|
490 | 490 | def alter(self, *p, **k): |
|
491 | 491 | """Makes a call to :func:`alter_column` for the column this |
|
492 | 492 | method is called on. |
|
493 | 493 | """ |
|
494 | 494 | if 'table' not in k: |
|
495 | 495 | k['table'] = self.table |
|
496 | 496 | if 'engine' not in k: |
|
497 | 497 | k['engine'] = k['table'].bind |
|
498 | 498 | return alter_column(self, *p, **k) |
|
499 | 499 | |
|
500 | 500 | def create(self, table=None, index_name=None, unique_name=None, |
|
501 | 501 | primary_key_name=None, populate_default=True, connection=None, **kwargs): |
|
502 | 502 | """Create this column in the database. |
|
503 | 503 | |
|
504 | 504 | Assumes the given table exists. ``ALTER TABLE ADD COLUMN``, |
|
505 | 505 | for most databases. |
|
506 | 506 | |
|
507 | 507 | :param table: Table instance to create on. |
|
508 | 508 | :param index_name: Creates :class:`ChangesetIndex` on this column. |
|
509 | 509 | :param unique_name: Creates :class:\ |
|
510 | 510 | `~migrate.changeset.constraint.UniqueConstraint` on this column. |
|
511 | 511 | :param primary_key_name: Creates :class:\ |
|
512 | 512 | `~migrate.changeset.constraint.PrimaryKeyConstraint` on this column. |
|
513 | 513 | :param alter_metadata: If True, column will be added to table object. |
|
514 | 514 | :param populate_default: If True, created column will be \ |
|
515 | 515 | populated with defaults |
|
516 | 516 | :param connection: reuse connection istead of creating new one. |
|
517 | 517 | :type table: Table instance |
|
518 | 518 | :type index_name: string |
|
519 | 519 | :type unique_name: string |
|
520 | 520 | :type primary_key_name: string |
|
521 | 521 | :type alter_metadata: bool |
|
522 | 522 | :type populate_default: bool |
|
523 | 523 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
524 | 524 | |
|
525 | 525 | :returns: self |
|
526 | 526 | """ |
|
527 | 527 | self.populate_default = populate_default |
|
528 | 528 | self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) |
|
529 | 529 | self.index_name = index_name |
|
530 | 530 | self.unique_name = unique_name |
|
531 | 531 | self.primary_key_name = primary_key_name |
|
532 | 532 | for cons in ('index_name', 'unique_name', 'primary_key_name'): |
|
533 | 533 | self._check_sanity_constraints(cons) |
|
534 | 534 | |
|
535 | 535 | if self.alter_metadata: |
|
536 | 536 | self.add_to_table(table) |
|
537 | 537 | engine = self.table.bind |
|
538 | 538 | visitorcallable = get_engine_visitor(engine, 'columngenerator') |
|
539 | 539 | engine._run_visitor(visitorcallable, self, connection, **kwargs) |
|
540 | 540 | |
|
541 | 541 | # TODO: reuse existing connection |
|
542 | 542 | if self.populate_default and self.default is not None: |
|
543 | 543 | stmt = table.update().values({self: engine._execute_default(self.default)}) |
|
544 | 544 | engine.execute(stmt) |
|
545 | 545 | |
|
546 | 546 | return self |
|
547 | 547 | |
|
548 | 548 | def drop(self, table=None, connection=None, **kwargs): |
|
549 | 549 | """Drop this column from the database, leaving its table intact. |
|
550 | 550 | |
|
551 | 551 | ``ALTER TABLE DROP COLUMN``, for most databases. |
|
552 | 552 | |
|
553 | 553 | :param alter_metadata: If True, column will be removed from table object. |
|
554 | 554 | :type alter_metadata: bool |
|
555 | 555 | :param connection: reuse connection istead of creating new one. |
|
556 | 556 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
557 | 557 | """ |
|
558 | 558 | self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) |
|
559 | 559 | if table is not None: |
|
560 | 560 | self.table = table |
|
561 | 561 | engine = self.table.bind |
|
562 | 562 | if self.alter_metadata: |
|
563 | 563 | self.remove_from_table(self.table, unset_table=False) |
|
564 | 564 | visitorcallable = get_engine_visitor(engine, 'columndropper') |
|
565 | 565 | engine._run_visitor(visitorcallable, self, connection, **kwargs) |
|
566 | 566 | if self.alter_metadata: |
|
567 | 567 | self.table = None |
|
568 | 568 | return self |
|
569 | 569 | |
|
570 | 570 | def add_to_table(self, table): |
|
571 | 571 | if table is not None and self.table is None: |
|
572 | 572 | self._set_parent(table) |
|
573 | 573 | |
|
574 | 574 | def _col_name_in_constraint(self, cons, name): |
|
575 | 575 | return False |
|
576 | 576 | |
|
577 | 577 | def remove_from_table(self, table, unset_table=True): |
|
578 | 578 | # TODO: remove primary keys, constraints, etc |
|
579 | 579 | if unset_table: |
|
580 | 580 | self.table = None |
|
581 | 581 | |
|
582 | 582 | to_drop = set() |
|
583 | 583 | for index in table.indexes: |
|
584 | 584 | columns = [] |
|
585 | 585 | for col in index.columns: |
|
586 | 586 | if col.name != self.name: |
|
587 | 587 | columns.append(col) |
|
588 | 588 | if columns: |
|
589 | 589 | index.columns = columns |
|
590 | 590 | else: |
|
591 | 591 | to_drop.add(index) |
|
592 | 592 | table.indexes = table.indexes - to_drop |
|
593 | 593 | |
|
594 | 594 | to_drop = set() |
|
595 | 595 | for cons in table.constraints: |
|
596 | 596 | # TODO: deal with other types of constraint |
|
597 | 597 | if isinstance(cons, (ForeignKeyConstraint, |
|
598 | 598 | UniqueConstraint)): |
|
599 | 599 | for col_name in cons.columns: |
|
600 | 600 | if not isinstance(col_name, basestring): |
|
601 | 601 | col_name = col_name.name |
|
602 | 602 | if self.name == col_name: |
|
603 | 603 | to_drop.add(cons) |
|
604 | 604 | table.constraints = table.constraints - to_drop |
|
605 | 605 | |
|
606 | 606 | if table.c.contains_column(self): |
|
607 | 607 | table.c.remove(self) |
|
608 | 608 | |
|
609 | 609 | # TODO: this is fixed in 0.6 |
|
610 | 610 | def copy_fixed(self, **kw): |
|
611 | 611 | """Create a copy of this ``Column``, with all attributes.""" |
|
612 | 612 | return sqlalchemy.Column(self.name, self.type, self.default, |
|
613 | 613 | key=self.key, |
|
614 | 614 | primary_key=self.primary_key, |
|
615 | 615 | nullable=self.nullable, |
|
616 | 616 | quote=self.quote, |
|
617 | 617 | index=self.index, |
|
618 | 618 | unique=self.unique, |
|
619 | 619 | onupdate=self.onupdate, |
|
620 | 620 | autoincrement=self.autoincrement, |
|
621 | 621 | server_default=self.server_default, |
|
622 | 622 | server_onupdate=self.server_onupdate, |
|
623 | 623 | *[c.copy(**kw) for c in self.constraints]) |
|
624 | 624 | |
|
625 | 625 | def _check_sanity_constraints(self, name): |
|
626 | 626 | """Check if constraints names are correct""" |
|
627 | 627 | obj = getattr(self, name) |
|
628 | 628 | if (getattr(self, name[:-5]) and not obj): |
|
629 | 629 | raise InvalidConstraintError("Column.create() accepts index_name," |
|
630 | 630 | " primary_key_name and unique_name to generate constraints") |
|
631 | 631 | if not isinstance(obj, basestring) and obj is not None: |
|
632 | 632 | raise InvalidConstraintError( |
|
633 | 633 | "%s argument for column must be constraint name" % name) |
|
634 | 634 | |
|
635 | 635 | |
|
636 | 636 | class ChangesetIndex(object): |
|
637 | 637 | """Changeset extensions to SQLAlchemy Indexes.""" |
|
638 | 638 | |
|
639 | 639 | __visit_name__ = 'index' |
|
640 | 640 | |
|
641 | 641 | def rename(self, name, connection=None, **kwargs): |
|
642 | 642 | """Change the name of an index. |
|
643 | 643 | |
|
644 | 644 | :param name: New name of the Index. |
|
645 | 645 | :type name: string |
|
646 | 646 | :param alter_metadata: If True, Index object will be altered. |
|
647 | 647 | :type alter_metadata: bool |
|
648 | 648 | :param connection: reuse connection istead of creating new one. |
|
649 | 649 | :type connection: :class:`sqlalchemy.engine.base.Connection` instance |
|
650 | 650 | """ |
|
651 | 651 | self.alter_metadata = kwargs.pop('alter_metadata', DEFAULT_ALTER_METADATA) |
|
652 | 652 | engine = self.table.bind |
|
653 | 653 | self.new_name = name |
|
654 | 654 | visitorcallable = get_engine_visitor(engine, 'schemachanger') |
|
655 | 655 | engine._run_visitor(visitorcallable, self, connection, **kwargs) |
|
656 | 656 | if self.alter_metadata: |
|
657 | 657 | self.name = name |
|
658 | 658 | |
|
659 | 659 | |
|
660 | 660 | class ChangesetDefaultClause(object): |
|
661 | 661 | """Implements comparison between :class:`DefaultClause` instances""" |
|
662 | 662 | |
|
663 | 663 | def __eq__(self, other): |
|
664 | 664 | if isinstance(other, self.__class__): |
|
665 | 665 | if self.arg == other.arg: |
|
666 | 666 | return True |
|
667 | 667 | |
|
668 | 668 | def __ne__(self, other): |
|
669 | 669 | return not self.__eq__(other) |
@@ -1,383 +1,383 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | This module provides an external API to the versioning system. |
|
3 | 3 | |
|
4 | 4 | .. versionchanged:: 0.6.0 |
|
5 | 5 | :func:`migrate.versioning.api.test` and schema diff functions |
|
6 | 6 | changed order of positional arguments so all accept `url` and `repository` |
|
7 | 7 | as first arguments. |
|
8 | 8 | |
|
9 | 9 | .. versionchanged:: 0.5.4 |
|
10 | 10 | ``--preview_sql`` displays source file when using SQL scripts. |
|
11 | 11 | If Python script is used, it runs the action with mocked engine and |
|
12 | 12 | returns captured SQL statements. |
|
13 | 13 | |
|
14 | 14 | .. versionchanged:: 0.5.4 |
|
15 | 15 | Deprecated ``--echo`` parameter in favour of new |
|
16 | 16 | :func:`migrate.versioning.util.construct_engine` behavior. |
|
17 | 17 | """ |
|
18 | 18 | |
|
19 | 19 | # Dear migrate developers, |
|
20 | 20 | # |
|
21 | 21 | # please do not comment this module using sphinx syntax because its |
|
22 | 22 | # docstrings are presented as user help and most users cannot |
|
23 | 23 | # interpret sphinx annotated ReStructuredText. |
|
24 | 24 | # |
|
25 | 25 | # Thanks, |
|
26 | 26 | # Jan Dittberner |
|
27 | 27 | |
|
28 | 28 | import sys |
|
29 | 29 | import inspect |
|
30 | 30 | import logging |
|
31 | 31 | |
|
32 | from migrate import exceptions | |
|
33 | from migrate.versioning import (repository, schema, version, | |
|
32 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
33 | from rhodecode.lib.dbmigrate.migrate.versioning import (repository, schema, version, | |
|
34 | 34 | script as script_) # command name conflict |
|
35 | from migrate.versioning.util import catch_known_errors, with_engine | |
|
35 | from rhodecode.lib.dbmigrate.migrate.versioning.util import catch_known_errors, with_engine | |
|
36 | 36 | |
|
37 | 37 | |
|
38 | 38 | log = logging.getLogger(__name__) |
|
39 | 39 | command_desc = { |
|
40 | 40 | 'help': 'displays help on a given command', |
|
41 | 41 | 'create': 'create an empty repository at the specified path', |
|
42 | 42 | 'script': 'create an empty change Python script', |
|
43 | 43 | 'script_sql': 'create empty change SQL scripts for given database', |
|
44 | 44 | 'version': 'display the latest version available in a repository', |
|
45 | 45 | 'db_version': 'show the current version of the repository under version control', |
|
46 | 46 | 'source': 'display the Python code for a particular version in this repository', |
|
47 | 47 | 'version_control': 'mark a database as under this repository\'s version control', |
|
48 | 48 | 'upgrade': 'upgrade a database to a later version', |
|
49 | 49 | 'downgrade': 'downgrade a database to an earlier version', |
|
50 | 50 | 'drop_version_control': 'removes version control from a database', |
|
51 | 51 | 'manage': 'creates a Python script that runs Migrate with a set of default values', |
|
52 | 52 | 'test': 'performs the upgrade and downgrade command on the given database', |
|
53 | 53 | 'compare_model_to_db': 'compare MetaData against the current database state', |
|
54 | 54 | 'create_model': 'dump the current database as a Python model to stdout', |
|
55 | 55 | 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData', |
|
56 | 56 | 'update_db_from_model': 'modify the database to match the structure of the current MetaData', |
|
57 | 57 | } |
|
58 | 58 | __all__ = command_desc.keys() |
|
59 | 59 | |
|
60 | 60 | Repository = repository.Repository |
|
61 | 61 | ControlledSchema = schema.ControlledSchema |
|
62 | 62 | VerNum = version.VerNum |
|
63 | 63 | PythonScript = script_.PythonScript |
|
64 | 64 | SqlScript = script_.SqlScript |
|
65 | 65 | |
|
66 | 66 | |
|
67 | 67 | # deprecated |
|
68 | 68 | def help(cmd=None, **opts): |
|
69 | 69 | """%prog help COMMAND |
|
70 | 70 | |
|
71 | 71 | Displays help on a given command. |
|
72 | 72 | """ |
|
73 | 73 | if cmd is None: |
|
74 | 74 | raise exceptions.UsageError(None) |
|
75 | 75 | try: |
|
76 | 76 | func = globals()[cmd] |
|
77 | 77 | except: |
|
78 | 78 | raise exceptions.UsageError( |
|
79 | 79 | "'%s' isn't a valid command. Try 'help COMMAND'" % cmd) |
|
80 | 80 | ret = func.__doc__ |
|
81 | 81 | if sys.argv[0]: |
|
82 | 82 | ret = ret.replace('%prog', sys.argv[0]) |
|
83 | 83 | return ret |
|
84 | 84 | |
|
85 | 85 | @catch_known_errors |
|
86 | 86 | def create(repository, name, **opts): |
|
87 | 87 | """%prog create REPOSITORY_PATH NAME [--table=TABLE] |
|
88 | 88 | |
|
89 | 89 | Create an empty repository at the specified path. |
|
90 | 90 | |
|
91 | 91 | You can specify the version_table to be used; by default, it is |
|
92 | 92 | 'migrate_version'. This table is created in all version-controlled |
|
93 | 93 | databases. |
|
94 | 94 | """ |
|
95 | 95 | repo_path = Repository.create(repository, name, **opts) |
|
96 | 96 | |
|
97 | 97 | |
|
98 | 98 | @catch_known_errors |
|
99 | 99 | def script(description, repository, **opts): |
|
100 | 100 | """%prog script DESCRIPTION REPOSITORY_PATH |
|
101 | 101 | |
|
102 | 102 | Create an empty change script using the next unused version number |
|
103 | 103 | appended with the given description. |
|
104 | 104 | |
|
105 | 105 | For instance, manage.py script "Add initial tables" creates: |
|
106 | 106 | repository/versions/001_Add_initial_tables.py |
|
107 | 107 | """ |
|
108 | 108 | repo = Repository(repository) |
|
109 | 109 | repo.create_script(description, **opts) |
|
110 | 110 | |
|
111 | 111 | |
|
112 | 112 | @catch_known_errors |
|
113 | 113 | def script_sql(database, repository, **opts): |
|
114 | 114 | """%prog script_sql DATABASE REPOSITORY_PATH |
|
115 | 115 | |
|
116 | 116 | Create empty change SQL scripts for given DATABASE, where DATABASE |
|
117 | 117 | is either specific ('postgres', 'mysql', 'oracle', 'sqlite', etc.) |
|
118 | 118 | or generic ('default'). |
|
119 | 119 | |
|
120 | 120 | For instance, manage.py script_sql postgres creates: |
|
121 | 121 | repository/versions/001_postgres_upgrade.sql and |
|
122 | 122 | repository/versions/001_postgres_postgres.sql |
|
123 | 123 | """ |
|
124 | 124 | repo = Repository(repository) |
|
125 | 125 | repo.create_script_sql(database, **opts) |
|
126 | 126 | |
|
127 | 127 | |
|
128 | 128 | def version(repository, **opts): |
|
129 | 129 | """%prog version REPOSITORY_PATH |
|
130 | 130 | |
|
131 | 131 | Display the latest version available in a repository. |
|
132 | 132 | """ |
|
133 | 133 | repo = Repository(repository) |
|
134 | 134 | return repo.latest |
|
135 | 135 | |
|
136 | 136 | |
|
137 | 137 | @with_engine |
|
138 | 138 | def db_version(url, repository, **opts): |
|
139 | 139 | """%prog db_version URL REPOSITORY_PATH |
|
140 | 140 | |
|
141 | 141 | Show the current version of the repository with the given |
|
142 | 142 | connection string, under version control of the specified |
|
143 | 143 | repository. |
|
144 | 144 | |
|
145 | 145 | The url should be any valid SQLAlchemy connection string. |
|
146 | 146 | """ |
|
147 | 147 | engine = opts.pop('engine') |
|
148 | 148 | schema = ControlledSchema(engine, repository) |
|
149 | 149 | return schema.version |
|
150 | 150 | |
|
151 | 151 | |
|
152 | 152 | def source(version, dest=None, repository=None, **opts): |
|
153 | 153 | """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH |
|
154 | 154 | |
|
155 | 155 | Display the Python code for a particular version in this |
|
156 | 156 | repository. Save it to the file at DESTINATION or, if omitted, |
|
157 | 157 | send to stdout. |
|
158 | 158 | """ |
|
159 | 159 | if repository is None: |
|
160 | 160 | raise exceptions.UsageError("A repository must be specified") |
|
161 | 161 | repo = Repository(repository) |
|
162 | 162 | ret = repo.version(version).script().source() |
|
163 | 163 | if dest is not None: |
|
164 | 164 | dest = open(dest, 'w') |
|
165 | 165 | dest.write(ret) |
|
166 | 166 | dest.close() |
|
167 | 167 | ret = None |
|
168 | 168 | return ret |
|
169 | 169 | |
|
170 | 170 | |
|
171 | 171 | def upgrade(url, repository, version=None, **opts): |
|
172 | 172 | """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql] |
|
173 | 173 | |
|
174 | 174 | Upgrade a database to a later version. |
|
175 | 175 | |
|
176 | 176 | This runs the upgrade() function defined in your change scripts. |
|
177 | 177 | |
|
178 | 178 | By default, the database is updated to the latest available |
|
179 | 179 | version. You may specify a version instead, if you wish. |
|
180 | 180 | |
|
181 | 181 | You may preview the Python or SQL code to be executed, rather than |
|
182 | 182 | actually executing it, using the appropriate 'preview' option. |
|
183 | 183 | """ |
|
184 | 184 | err = "Cannot upgrade a database of version %s to version %s. "\ |
|
185 | 185 | "Try 'downgrade' instead." |
|
186 | 186 | return _migrate(url, repository, version, upgrade=True, err=err, **opts) |
|
187 | 187 | |
|
188 | 188 | |
|
189 | 189 | def downgrade(url, repository, version, **opts): |
|
190 | 190 | """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql] |
|
191 | 191 | |
|
192 | 192 | Downgrade a database to an earlier version. |
|
193 | 193 | |
|
194 | 194 | This is the reverse of upgrade; this runs the downgrade() function |
|
195 | 195 | defined in your change scripts. |
|
196 | 196 | |
|
197 | 197 | You may preview the Python or SQL code to be executed, rather than |
|
198 | 198 | actually executing it, using the appropriate 'preview' option. |
|
199 | 199 | """ |
|
200 | 200 | err = "Cannot downgrade a database of version %s to version %s. "\ |
|
201 | 201 | "Try 'upgrade' instead." |
|
202 | 202 | return _migrate(url, repository, version, upgrade=False, err=err, **opts) |
|
203 | 203 | |
|
204 | 204 | @with_engine |
|
205 | 205 | def test(url, repository, **opts): |
|
206 | 206 | """%prog test URL REPOSITORY_PATH [VERSION] |
|
207 | 207 | |
|
208 | 208 | Performs the upgrade and downgrade option on the given |
|
209 | 209 | database. This is not a real test and may leave the database in a |
|
210 | 210 | bad state. You should therefore better run the test on a copy of |
|
211 | 211 | your database. |
|
212 | 212 | """ |
|
213 | 213 | engine = opts.pop('engine') |
|
214 | 214 | repos = Repository(repository) |
|
215 | 215 | script = repos.version(None).script() |
|
216 | 216 | |
|
217 | 217 | # Upgrade |
|
218 | 218 | log.info("Upgrading...") |
|
219 | 219 | script.run(engine, 1) |
|
220 | 220 | log.info("done") |
|
221 | 221 | |
|
222 | 222 | log.info("Downgrading...") |
|
223 | 223 | script.run(engine, -1) |
|
224 | 224 | log.info("done") |
|
225 | 225 | log.info("Success") |
|
226 | 226 | |
|
227 | 227 | |
|
228 | 228 | @with_engine |
|
229 | 229 | def version_control(url, repository, version=None, **opts): |
|
230 | 230 | """%prog version_control URL REPOSITORY_PATH [VERSION] |
|
231 | 231 | |
|
232 | 232 | Mark a database as under this repository's version control. |
|
233 | 233 | |
|
234 | 234 | Once a database is under version control, schema changes should |
|
235 | 235 | only be done via change scripts in this repository. |
|
236 | 236 | |
|
237 | 237 | This creates the table version_table in the database. |
|
238 | 238 | |
|
239 | 239 | The url should be any valid SQLAlchemy connection string. |
|
240 | 240 | |
|
241 | 241 | By default, the database begins at version 0 and is assumed to be |
|
242 | 242 | empty. If the database is not empty, you may specify a version at |
|
243 | 243 | which to begin instead. No attempt is made to verify this |
|
244 | 244 | version's correctness - the database schema is expected to be |
|
245 | 245 | identical to what it would be if the database were created from |
|
246 | 246 | scratch. |
|
247 | 247 | """ |
|
248 | 248 | engine = opts.pop('engine') |
|
249 | 249 | ControlledSchema.create(engine, repository, version) |
|
250 | 250 | |
|
251 | 251 | |
|
252 | 252 | @with_engine |
|
253 | 253 | def drop_version_control(url, repository, **opts): |
|
254 | 254 | """%prog drop_version_control URL REPOSITORY_PATH |
|
255 | 255 | |
|
256 | 256 | Removes version control from a database. |
|
257 | 257 | """ |
|
258 | 258 | engine = opts.pop('engine') |
|
259 | 259 | schema = ControlledSchema(engine, repository) |
|
260 | 260 | schema.drop() |
|
261 | 261 | |
|
262 | 262 | |
|
263 | 263 | def manage(file, **opts): |
|
264 | 264 | """%prog manage FILENAME [VARIABLES...] |
|
265 | 265 | |
|
266 | 266 | Creates a script that runs Migrate with a set of default values. |
|
267 | 267 | |
|
268 | 268 | For example:: |
|
269 | 269 | |
|
270 | 270 | %prog manage manage.py --repository=/path/to/repository \ |
|
271 | 271 | --url=sqlite:///project.db |
|
272 | 272 | |
|
273 | 273 | would create the script manage.py. The following two commands |
|
274 | 274 | would then have exactly the same results:: |
|
275 | 275 | |
|
276 | 276 | python manage.py version |
|
277 | 277 | %prog version --repository=/path/to/repository |
|
278 | 278 | """ |
|
279 | 279 | Repository.create_manage_file(file, **opts) |
|
280 | 280 | |
|
281 | 281 | |
|
282 | 282 | @with_engine |
|
283 | 283 | def compare_model_to_db(url, repository, model, **opts): |
|
284 | 284 | """%prog compare_model_to_db URL REPOSITORY_PATH MODEL |
|
285 | 285 | |
|
286 | 286 | Compare the current model (assumed to be a module level variable |
|
287 | 287 | of type sqlalchemy.MetaData) against the current database. |
|
288 | 288 | |
|
289 | 289 | NOTE: This is EXPERIMENTAL. |
|
290 | 290 | """ # TODO: get rid of EXPERIMENTAL label |
|
291 | 291 | engine = opts.pop('engine') |
|
292 | 292 | return ControlledSchema.compare_model_to_db(engine, model, repository) |
|
293 | 293 | |
|
294 | 294 | |
|
295 | 295 | @with_engine |
|
296 | 296 | def create_model(url, repository, **opts): |
|
297 | 297 | """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True] |
|
298 | 298 | |
|
299 | 299 | Dump the current database as a Python model to stdout. |
|
300 | 300 | |
|
301 | 301 | NOTE: This is EXPERIMENTAL. |
|
302 | 302 | """ # TODO: get rid of EXPERIMENTAL label |
|
303 | 303 | engine = opts.pop('engine') |
|
304 | 304 | declarative = opts.get('declarative', False) |
|
305 | 305 | return ControlledSchema.create_model(engine, repository, declarative) |
|
306 | 306 | |
|
307 | 307 | |
|
308 | 308 | @catch_known_errors |
|
309 | 309 | @with_engine |
|
310 | 310 | def make_update_script_for_model(url, repository, oldmodel, model, **opts): |
|
311 | 311 | """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH |
|
312 | 312 | |
|
313 | 313 | Create a script changing the old Python model to the new (current) |
|
314 | 314 | Python model, sending to stdout. |
|
315 | 315 | |
|
316 | 316 | NOTE: This is EXPERIMENTAL. |
|
317 | 317 | """ # TODO: get rid of EXPERIMENTAL label |
|
318 | 318 | engine = opts.pop('engine') |
|
319 | 319 | return PythonScript.make_update_script_for_model( |
|
320 | 320 | engine, oldmodel, model, repository, **opts) |
|
321 | 321 | |
|
322 | 322 | |
|
323 | 323 | @with_engine |
|
324 | 324 | def update_db_from_model(url, repository, model, **opts): |
|
325 | 325 | """%prog update_db_from_model URL REPOSITORY_PATH MODEL |
|
326 | 326 | |
|
327 | 327 | Modify the database to match the structure of the current Python |
|
328 | 328 | model. This also sets the db_version number to the latest in the |
|
329 | 329 | repository. |
|
330 | 330 | |
|
331 | 331 | NOTE: This is EXPERIMENTAL. |
|
332 | 332 | """ # TODO: get rid of EXPERIMENTAL label |
|
333 | 333 | engine = opts.pop('engine') |
|
334 | 334 | schema = ControlledSchema(engine, repository) |
|
335 | 335 | schema.update_db_from_model(model) |
|
336 | 336 | |
|
337 | 337 | @with_engine |
|
338 | 338 | def _migrate(url, repository, version, upgrade, err, **opts): |
|
339 | 339 | engine = opts.pop('engine') |
|
340 | 340 | url = str(engine.url) |
|
341 | 341 | schema = ControlledSchema(engine, repository) |
|
342 | 342 | version = _migrate_version(schema, version, upgrade, err) |
|
343 | 343 | |
|
344 | 344 | changeset = schema.changeset(version) |
|
345 | 345 | for ver, change in changeset: |
|
346 | 346 | nextver = ver + changeset.step |
|
347 | 347 | log.info('%s -> %s... ', ver, nextver) |
|
348 | 348 | |
|
349 | 349 | if opts.get('preview_sql'): |
|
350 | 350 | if isinstance(change, PythonScript): |
|
351 | 351 | log.info(change.preview_sql(url, changeset.step, **opts)) |
|
352 | 352 | elif isinstance(change, SqlScript): |
|
353 | 353 | log.info(change.source()) |
|
354 | 354 | |
|
355 | 355 | elif opts.get('preview_py'): |
|
356 | 356 | if not isinstance(change, PythonScript): |
|
357 | 357 | raise exceptions.UsageError("Python source can be only displayed" |
|
358 | 358 | " for python migration files") |
|
359 | 359 | source_ver = max(ver, nextver) |
|
360 | 360 | module = schema.repository.version(source_ver).script().module |
|
361 | 361 | funcname = upgrade and "upgrade" or "downgrade" |
|
362 | 362 | func = getattr(module, funcname) |
|
363 | 363 | log.info(inspect.getsource(func)) |
|
364 | 364 | else: |
|
365 | 365 | schema.runchange(ver, change, changeset.step) |
|
366 | 366 | log.info('done') |
|
367 | 367 | |
|
368 | 368 | |
|
369 | 369 | def _migrate_version(schema, version, upgrade, err): |
|
370 | 370 | if version is None: |
|
371 | 371 | return version |
|
372 | 372 | # Version is specified: ensure we're upgrading in the right direction |
|
373 | 373 | # (current version < target version for upgrading; reverse for down) |
|
374 | 374 | version = VerNum(version) |
|
375 | 375 | cur = schema.version |
|
376 | 376 | if upgrade is not None: |
|
377 | 377 | if upgrade: |
|
378 | 378 | direction = cur <= version |
|
379 | 379 | else: |
|
380 | 380 | direction = cur >= version |
|
381 | 381 | if not direction: |
|
382 | 382 | raise exceptions.KnownError(err % (cur, version)) |
|
383 | 383 | return version |
@@ -1,27 +1,27 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Configuration parser module. |
|
3 | 3 | """ |
|
4 | 4 | |
|
5 | 5 | from ConfigParser import ConfigParser |
|
6 | 6 | |
|
7 | from migrate.versioning.config import * | |
|
8 | from migrate.versioning import pathed | |
|
7 | from rhodecode.lib.dbmigrate.migrate.versioning.config import * | |
|
8 | from rhodecode.lib.dbmigrate.migrate.versioning import pathed | |
|
9 | 9 | |
|
10 | 10 | |
|
11 | 11 | class Parser(ConfigParser): |
|
12 | 12 | """A project configuration file.""" |
|
13 | 13 | |
|
14 | 14 | def to_dict(self, sections=None): |
|
15 | 15 | """It's easier to access config values like dictionaries""" |
|
16 | 16 | return self._sections |
|
17 | 17 | |
|
18 | 18 | |
|
19 | 19 | class Config(pathed.Pathed, Parser): |
|
20 | 20 | """Configuration class.""" |
|
21 | 21 | |
|
22 | 22 | def __init__(self, path, *p, **k): |
|
23 | 23 | """Confirm the config file exists; read it.""" |
|
24 | 24 | self.require_found(path) |
|
25 | 25 | pathed.Pathed.__init__(self, path) |
|
26 | 26 | Parser.__init__(self, *p, **k) |
|
27 | 27 | self.read(path) |
@@ -1,254 +1,253 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Code to generate a Python model from a database or differences |
|
3 | 3 | between a model and database. |
|
4 | 4 | |
|
5 | 5 | Some of this is borrowed heavily from the AutoCode project at: |
|
6 | 6 | http://code.google.com/p/sqlautocode/ |
|
7 | 7 | """ |
|
8 | 8 | |
|
9 | 9 | import sys |
|
10 | 10 | import logging |
|
11 | 11 | |
|
12 | 12 | import sqlalchemy |
|
13 | 13 | |
|
14 | import migrate | |
|
15 | import migrate.changeset | |
|
16 | ||
|
14 | from rhodecode.lib.dbmigrate import migrate | |
|
15 | from rhodecode.lib.dbmigrate.migrate import changeset | |
|
17 | 16 | |
|
18 | 17 | log = logging.getLogger(__name__) |
|
19 | 18 | HEADER = """ |
|
20 | 19 | ## File autogenerated by genmodel.py |
|
21 | 20 | |
|
22 | 21 | from sqlalchemy import * |
|
23 | 22 | meta = MetaData() |
|
24 | 23 | """ |
|
25 | 24 | |
|
26 | 25 | DECLARATIVE_HEADER = """ |
|
27 | 26 | ## File autogenerated by genmodel.py |
|
28 | 27 | |
|
29 | 28 | from sqlalchemy import * |
|
30 | 29 | from sqlalchemy.ext import declarative |
|
31 | 30 | |
|
32 | 31 | Base = declarative.declarative_base() |
|
33 | 32 | """ |
|
34 | 33 | |
|
35 | 34 | |
|
36 | 35 | class ModelGenerator(object): |
|
37 | 36 | |
|
38 | 37 | def __init__(self, diff, engine, declarative=False): |
|
39 | 38 | self.diff = diff |
|
40 | 39 | self.engine = engine |
|
41 | 40 | self.declarative = declarative |
|
42 | 41 | |
|
43 | 42 | def column_repr(self, col): |
|
44 | 43 | kwarg = [] |
|
45 | 44 | if col.key != col.name: |
|
46 | 45 | kwarg.append('key') |
|
47 | 46 | if col.primary_key: |
|
48 | 47 | col.primary_key = True # otherwise it dumps it as 1 |
|
49 | 48 | kwarg.append('primary_key') |
|
50 | 49 | if not col.nullable: |
|
51 | 50 | kwarg.append('nullable') |
|
52 | 51 | if col.onupdate: |
|
53 | 52 | kwarg.append('onupdate') |
|
54 | 53 | if col.default: |
|
55 | 54 | if col.primary_key: |
|
56 | 55 | # I found that PostgreSQL automatically creates a |
|
57 | 56 | # default value for the sequence, but let's not show |
|
58 | 57 | # that. |
|
59 | 58 | pass |
|
60 | 59 | else: |
|
61 | 60 | kwarg.append('default') |
|
62 | 61 | ks = ', '.join('%s=%r' % (k, getattr(col, k)) for k in kwarg) |
|
63 | 62 | |
|
64 | 63 | # crs: not sure if this is good idea, but it gets rid of extra |
|
65 | 64 | # u'' |
|
66 | 65 | name = col.name.encode('utf8') |
|
67 | 66 | |
|
68 | 67 | type_ = col.type |
|
69 | 68 | for cls in col.type.__class__.__mro__: |
|
70 | 69 | if cls.__module__ == 'sqlalchemy.types' and \ |
|
71 | 70 | not cls.__name__.isupper(): |
|
72 | 71 | if cls is not type_.__class__: |
|
73 | 72 | type_ = cls() |
|
74 | 73 | break |
|
75 | 74 | |
|
76 | 75 | data = { |
|
77 | 76 | 'name': name, |
|
78 | 77 | 'type': type_, |
|
79 | 78 | 'constraints': ', '.join([repr(cn) for cn in col.constraints]), |
|
80 | 79 | 'args': ks and ks or ''} |
|
81 | 80 | |
|
82 | 81 | if data['constraints']: |
|
83 | 82 | if data['args']: |
|
84 | 83 | data['args'] = ',' + data['args'] |
|
85 | 84 | |
|
86 | 85 | if data['constraints'] or data['args']: |
|
87 | 86 | data['maybeComma'] = ',' |
|
88 | 87 | else: |
|
89 | 88 | data['maybeComma'] = '' |
|
90 | 89 | |
|
91 | 90 | commonStuff = """ %(maybeComma)s %(constraints)s %(args)s)""" % data |
|
92 | 91 | commonStuff = commonStuff.strip() |
|
93 | 92 | data['commonStuff'] = commonStuff |
|
94 | 93 | if self.declarative: |
|
95 | 94 | return """%(name)s = Column(%(type)r%(commonStuff)s""" % data |
|
96 | 95 | else: |
|
97 | 96 | return """Column(%(name)r, %(type)r%(commonStuff)s""" % data |
|
98 | 97 | |
|
99 | 98 | def getTableDefn(self, table): |
|
100 | 99 | out = [] |
|
101 | 100 | tableName = table.name |
|
102 | 101 | if self.declarative: |
|
103 | 102 | out.append("class %(table)s(Base):" % {'table': tableName}) |
|
104 | 103 | out.append(" __tablename__ = '%(table)s'" % {'table': tableName}) |
|
105 | 104 | for col in table.columns: |
|
106 | 105 | out.append(" %s" % self.column_repr(col)) |
|
107 | 106 | else: |
|
108 | 107 | out.append("%(table)s = Table('%(table)s', meta," % \ |
|
109 | 108 | {'table': tableName}) |
|
110 | 109 | for col in table.columns: |
|
111 | 110 | out.append(" %s," % self.column_repr(col)) |
|
112 | 111 | out.append(")") |
|
113 | 112 | return out |
|
114 | 113 | |
|
115 | def _get_tables(self,missingA=False,missingB=False,modified=False): | |
|
114 | def _get_tables(self, missingA=False, missingB=False, modified=False): | |
|
116 | 115 | to_process = [] |
|
117 | for bool_,names,metadata in ( | |
|
118 | (missingA,self.diff.tables_missing_from_A,self.diff.metadataB), | |
|
119 | (missingB,self.diff.tables_missing_from_B,self.diff.metadataA), | |
|
120 | (modified,self.diff.tables_different,self.diff.metadataA), | |
|
116 | for bool_, names, metadata in ( | |
|
117 | (missingA, self.diff.tables_missing_from_A, self.diff.metadataB), | |
|
118 | (missingB, self.diff.tables_missing_from_B, self.diff.metadataA), | |
|
119 | (modified, self.diff.tables_different, self.diff.metadataA), | |
|
121 | 120 | ): |
|
122 | 121 | if bool_: |
|
123 | 122 | for name in names: |
|
124 | 123 | yield metadata.tables.get(name) |
|
125 | ||
|
124 | ||
|
126 | 125 | def toPython(self): |
|
127 | 126 | """Assume database is current and model is empty.""" |
|
128 | 127 | out = [] |
|
129 | 128 | if self.declarative: |
|
130 | 129 | out.append(DECLARATIVE_HEADER) |
|
131 | 130 | else: |
|
132 | 131 | out.append(HEADER) |
|
133 | 132 | out.append("") |
|
134 | 133 | for table in self._get_tables(missingA=True): |
|
135 | 134 | out.extend(self.getTableDefn(table)) |
|
136 | 135 | out.append("") |
|
137 | 136 | return '\n'.join(out) |
|
138 | 137 | |
|
139 | 138 | def toUpgradeDowngradePython(self, indent=' '): |
|
140 | 139 | ''' Assume model is most current and database is out-of-date. ''' |
|
141 | decls = ['from migrate.changeset import schema', | |
|
140 | decls = ['from rhodecode.lib.dbmigrate.migrate.changeset import schema', | |
|
142 | 141 | 'meta = MetaData()'] |
|
143 | 142 | for table in self._get_tables( |
|
144 | missingA=True,missingB=True,modified=True | |
|
143 | missingA=True, missingB=True, modified=True | |
|
145 | 144 | ): |
|
146 | 145 | decls.extend(self.getTableDefn(table)) |
|
147 | 146 | |
|
148 | 147 | upgradeCommands, downgradeCommands = [], [] |
|
149 | 148 | for tableName in self.diff.tables_missing_from_A: |
|
150 | 149 | upgradeCommands.append("%(table)s.drop()" % {'table': tableName}) |
|
151 | 150 | downgradeCommands.append("%(table)s.create()" % \ |
|
152 | 151 | {'table': tableName}) |
|
153 | 152 | for tableName in self.diff.tables_missing_from_B: |
|
154 | 153 | upgradeCommands.append("%(table)s.create()" % {'table': tableName}) |
|
155 | 154 | downgradeCommands.append("%(table)s.drop()" % {'table': tableName}) |
|
156 | 155 | |
|
157 | 156 | for tableName in self.diff.tables_different: |
|
158 | 157 | dbTable = self.diff.metadataB.tables[tableName] |
|
159 | 158 | missingInDatabase, missingInModel, diffDecl = \ |
|
160 | 159 | self.diff.colDiffs[tableName] |
|
161 | 160 | for col in missingInDatabase: |
|
162 | 161 | upgradeCommands.append('%s.columns[%r].create()' % ( |
|
163 | 162 | modelTable, col.name)) |
|
164 | 163 | downgradeCommands.append('%s.columns[%r].drop()' % ( |
|
165 | 164 | modelTable, col.name)) |
|
166 | 165 | for col in missingInModel: |
|
167 | 166 | upgradeCommands.append('%s.columns[%r].drop()' % ( |
|
168 | 167 | modelTable, col.name)) |
|
169 | 168 | downgradeCommands.append('%s.columns[%r].create()' % ( |
|
170 | 169 | modelTable, col.name)) |
|
171 | 170 | for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl: |
|
172 | 171 | upgradeCommands.append( |
|
173 | 172 | 'assert False, "Can\'t alter columns: %s:%s=>%s"', |
|
174 | 173 | modelTable, modelCol.name, databaseCol.name) |
|
175 | 174 | downgradeCommands.append( |
|
176 | 175 | 'assert False, "Can\'t alter columns: %s:%s=>%s"', |
|
177 | 176 | modelTable, modelCol.name, databaseCol.name) |
|
178 | 177 | pre_command = ' meta.bind = migrate_engine' |
|
179 | 178 | |
|
180 | 179 | return ( |
|
181 | 180 | '\n'.join(decls), |
|
182 | 181 | '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in upgradeCommands]), |
|
183 | 182 | '\n'.join([pre_command] + ['%s%s' % (indent, line) for line in downgradeCommands])) |
|
184 | 183 | |
|
185 | def _db_can_handle_this_change(self,td): | |
|
184 | def _db_can_handle_this_change(self, td): | |
|
186 | 185 | if (td.columns_missing_from_B |
|
187 | 186 | and not td.columns_missing_from_A |
|
188 | 187 | and not td.columns_different): |
|
189 | 188 | # Even sqlite can handle this. |
|
190 | 189 | return True |
|
191 | 190 | else: |
|
192 | 191 | return not self.engine.url.drivername.startswith('sqlite') |
|
193 | 192 | |
|
194 | 193 | def applyModel(self): |
|
195 | 194 | """Apply model to current database.""" |
|
196 | 195 | |
|
197 | 196 | meta = sqlalchemy.MetaData(self.engine) |
|
198 | 197 | |
|
199 | 198 | for table in self._get_tables(missingA=True): |
|
200 | 199 | table = table.tometadata(meta) |
|
201 | 200 | table.drop() |
|
202 | 201 | for table in self._get_tables(missingB=True): |
|
203 | 202 | table = table.tometadata(meta) |
|
204 | 203 | table.create() |
|
205 | 204 | for modelTable in self._get_tables(modified=True): |
|
206 | 205 | tableName = modelTable.name |
|
207 | 206 | modelTable = modelTable.tometadata(meta) |
|
208 | 207 | dbTable = self.diff.metadataB.tables[tableName] |
|
209 | 208 | |
|
210 | 209 | td = self.diff.tables_different[tableName] |
|
211 | ||
|
210 | ||
|
212 | 211 | if self._db_can_handle_this_change(td): |
|
213 | ||
|
212 | ||
|
214 | 213 | for col in td.columns_missing_from_B: |
|
215 | 214 | modelTable.columns[col].create() |
|
216 | 215 | for col in td.columns_missing_from_A: |
|
217 | 216 | dbTable.columns[col].drop() |
|
218 | 217 | # XXX handle column changes here. |
|
219 | 218 | else: |
|
220 | 219 | # Sqlite doesn't support drop column, so you have to |
|
221 | 220 | # do more: create temp table, copy data to it, drop |
|
222 | 221 | # old table, create new table, copy data back. |
|
223 | 222 | # |
|
224 | 223 | # I wonder if this is guaranteed to be unique? |
|
225 | 224 | tempName = '_temp_%s' % modelTable.name |
|
226 | 225 | |
|
227 | 226 | def getCopyStatement(): |
|
228 | 227 | preparer = self.engine.dialect.preparer |
|
229 | 228 | commonCols = [] |
|
230 | 229 | for modelCol in modelTable.columns: |
|
231 | 230 | if modelCol.name in dbTable.columns: |
|
232 | 231 | commonCols.append(modelCol.name) |
|
233 | 232 | commonColsStr = ', '.join(commonCols) |
|
234 | 233 | return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \ |
|
235 | 234 | (tableName, commonColsStr, commonColsStr, tempName) |
|
236 | 235 | |
|
237 | 236 | # Move the data in one transaction, so that we don't |
|
238 | 237 | # leave the database in a nasty state. |
|
239 | 238 | connection = self.engine.connect() |
|
240 | 239 | trans = connection.begin() |
|
241 | 240 | try: |
|
242 | 241 | connection.execute( |
|
243 | 242 | 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \ |
|
244 | 243 | (tempName, modelTable.name)) |
|
245 | 244 | # make sure the drop takes place inside our |
|
246 | 245 | # transaction with the bind parameter |
|
247 | 246 | modelTable.drop(bind=connection) |
|
248 | 247 | modelTable.create(bind=connection) |
|
249 | 248 | connection.execute(getCopyStatement()) |
|
250 | 249 | connection.execute('DROP TABLE %s' % tempName) |
|
251 | 250 | trans.commit() |
|
252 | 251 | except: |
|
253 | 252 | trans.rollback() |
|
254 | 253 | raise |
@@ -1,75 +1,75 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | A path/directory class. |
|
3 | 3 | """ |
|
4 | 4 | |
|
5 | 5 | import os |
|
6 | 6 | import shutil |
|
7 | 7 | import logging |
|
8 | 8 | |
|
9 | from migrate import exceptions | |
|
10 | from migrate.versioning.config import * | |
|
11 | from migrate.versioning.util import KeyedInstance | |
|
9 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
10 | from rhodecode.lib.dbmigrate.migrate.versioning.config import * | |
|
11 | from rhodecode.lib.dbmigrate.migrate.versioning.util import KeyedInstance | |
|
12 | 12 | |
|
13 | 13 | |
|
14 | 14 | log = logging.getLogger(__name__) |
|
15 | 15 | |
|
16 | 16 | class Pathed(KeyedInstance): |
|
17 | 17 | """ |
|
18 | 18 | A class associated with a path/directory tree. |
|
19 | 19 | |
|
20 | 20 | Only one instance of this class may exist for a particular file; |
|
21 | 21 | __new__ will return an existing instance if possible |
|
22 | 22 | """ |
|
23 | 23 | parent = None |
|
24 | 24 | |
|
25 | 25 | @classmethod |
|
26 | 26 | def _key(cls, path): |
|
27 | 27 | return str(path) |
|
28 | 28 | |
|
29 | 29 | def __init__(self, path): |
|
30 | 30 | self.path = path |
|
31 | 31 | if self.__class__.parent is not None: |
|
32 | 32 | self._init_parent(path) |
|
33 | 33 | |
|
34 | 34 | def _init_parent(self, path): |
|
35 | 35 | """Try to initialize this object's parent, if it has one""" |
|
36 | 36 | parent_path = self.__class__._parent_path(path) |
|
37 | 37 | self.parent = self.__class__.parent(parent_path) |
|
38 | 38 | log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path)) |
|
39 | 39 | self.parent._init_child(path, self) |
|
40 | 40 | |
|
41 | 41 | def _init_child(self, child, path): |
|
42 | 42 | """Run when a child of this object is initialized. |
|
43 | 43 | |
|
44 | 44 | Parameters: the child object; the path to this object (its |
|
45 | 45 | parent) |
|
46 | 46 | """ |
|
47 | 47 | |
|
48 | 48 | @classmethod |
|
49 | 49 | def _parent_path(cls, path): |
|
50 | 50 | """ |
|
51 | 51 | Fetch the path of this object's parent from this object's path. |
|
52 | 52 | """ |
|
53 | 53 | # os.path.dirname(), but strip directories like files (like |
|
54 | 54 | # unix basename) |
|
55 | 55 | # |
|
56 | 56 | # Treat directories like files... |
|
57 | 57 | if path[-1] == '/': |
|
58 | 58 | path = path[:-1] |
|
59 | 59 | ret = os.path.dirname(path) |
|
60 | 60 | return ret |
|
61 | 61 | |
|
62 | 62 | @classmethod |
|
63 | 63 | def require_notfound(cls, path): |
|
64 | 64 | """Ensures a given path does not already exist""" |
|
65 | 65 | if os.path.exists(path): |
|
66 | 66 | raise exceptions.PathFoundError(path) |
|
67 | 67 | |
|
68 | 68 | @classmethod |
|
69 | 69 | def require_found(cls, path): |
|
70 | 70 | """Ensures a given path already exists""" |
|
71 | 71 | if not os.path.exists(path): |
|
72 | 72 | raise exceptions.PathNotFoundError(path) |
|
73 | 73 | |
|
74 | 74 | def __str__(self): |
|
75 | 75 | return self.path |
@@ -1,231 +1,231 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | SQLAlchemy migrate repository management. |
|
3 | 3 | """ |
|
4 | 4 | import os |
|
5 | 5 | import shutil |
|
6 | 6 | import string |
|
7 | 7 | import logging |
|
8 | 8 | |
|
9 | 9 | from pkg_resources import resource_filename |
|
10 | 10 | from tempita import Template as TempitaTemplate |
|
11 | 11 | |
|
12 | from migrate import exceptions | |
|
13 | from migrate.versioning import version, pathed, cfgparse | |
|
14 | from migrate.versioning.template import Template | |
|
15 | from migrate.versioning.config import * | |
|
12 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
13 | from rhodecode.lib.dbmigrate.migrate.versioning import version, pathed, cfgparse | |
|
14 | from rhodecode.lib.dbmigrate.migrate.versioning.template import Template | |
|
15 | from rhodecode.lib.dbmigrate.migrate.versioning.config import * | |
|
16 | 16 | |
|
17 | 17 | |
|
18 | 18 | log = logging.getLogger(__name__) |
|
19 | 19 | |
|
20 | 20 | class Changeset(dict): |
|
21 | 21 | """A collection of changes to be applied to a database. |
|
22 | 22 | |
|
23 | 23 | Changesets are bound to a repository and manage a set of |
|
24 | 24 | scripts from that repository. |
|
25 | 25 | |
|
26 | 26 | Behaves like a dict, for the most part. Keys are ordered based on step value. |
|
27 | 27 | """ |
|
28 | 28 | |
|
29 | 29 | def __init__(self, start, *changes, **k): |
|
30 | 30 | """ |
|
31 | 31 | Give a start version; step must be explicitly stated. |
|
32 | 32 | """ |
|
33 | 33 | self.step = k.pop('step', 1) |
|
34 | 34 | self.start = version.VerNum(start) |
|
35 | 35 | self.end = self.start |
|
36 | 36 | for change in changes: |
|
37 | 37 | self.add(change) |
|
38 | 38 | |
|
39 | 39 | def __iter__(self): |
|
40 | 40 | return iter(self.items()) |
|
41 | 41 | |
|
42 | 42 | def keys(self): |
|
43 | 43 | """ |
|
44 | 44 | In a series of upgrades x -> y, keys are version x. Sorted. |
|
45 | 45 | """ |
|
46 | 46 | ret = super(Changeset, self).keys() |
|
47 | 47 | # Reverse order if downgrading |
|
48 | 48 | ret.sort(reverse=(self.step < 1)) |
|
49 | 49 | return ret |
|
50 | 50 | |
|
51 | 51 | def values(self): |
|
52 | 52 | return [self[k] for k in self.keys()] |
|
53 | 53 | |
|
54 | 54 | def items(self): |
|
55 | 55 | return zip(self.keys(), self.values()) |
|
56 | 56 | |
|
57 | 57 | def add(self, change): |
|
58 | 58 | """Add new change to changeset""" |
|
59 | 59 | key = self.end |
|
60 | 60 | self.end += self.step |
|
61 | 61 | self[key] = change |
|
62 | 62 | |
|
63 | 63 | def run(self, *p, **k): |
|
64 | 64 | """Run the changeset scripts""" |
|
65 | 65 | for version, script in self: |
|
66 | 66 | script.run(*p, **k) |
|
67 | 67 | |
|
68 | 68 | |
|
69 | 69 | class Repository(pathed.Pathed): |
|
70 | 70 | """A project's change script repository""" |
|
71 | 71 | |
|
72 | 72 | _config = 'migrate.cfg' |
|
73 | 73 | _versions = 'versions' |
|
74 | 74 | |
|
75 | 75 | def __init__(self, path): |
|
76 | 76 | log.debug('Loading repository %s...' % path) |
|
77 | 77 | self.verify(path) |
|
78 | 78 | super(Repository, self).__init__(path) |
|
79 | 79 | self.config = cfgparse.Config(os.path.join(self.path, self._config)) |
|
80 | 80 | self.versions = version.Collection(os.path.join(self.path, |
|
81 | 81 | self._versions)) |
|
82 | 82 | log.debug('Repository %s loaded successfully' % path) |
|
83 | 83 | log.debug('Config: %r' % self.config.to_dict()) |
|
84 | 84 | |
|
85 | 85 | @classmethod |
|
86 | 86 | def verify(cls, path): |
|
87 | 87 | """ |
|
88 | 88 | Ensure the target path is a valid repository. |
|
89 | 89 | |
|
90 | 90 | :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>` |
|
91 | 91 | """ |
|
92 | 92 | # Ensure the existence of required files |
|
93 | 93 | try: |
|
94 | 94 | cls.require_found(path) |
|
95 | 95 | cls.require_found(os.path.join(path, cls._config)) |
|
96 | 96 | cls.require_found(os.path.join(path, cls._versions)) |
|
97 | 97 | except exceptions.PathNotFoundError, e: |
|
98 | 98 | raise exceptions.InvalidRepositoryError(path) |
|
99 | 99 | |
|
100 | 100 | @classmethod |
|
101 | 101 | def prepare_config(cls, tmpl_dir, name, options=None): |
|
102 | 102 | """ |
|
103 | 103 | Prepare a project configuration file for a new project. |
|
104 | 104 | |
|
105 | 105 | :param tmpl_dir: Path to Repository template |
|
106 | 106 | :param config_file: Name of the config file in Repository template |
|
107 | 107 | :param name: Repository name |
|
108 | 108 | :type tmpl_dir: string |
|
109 | 109 | :type config_file: string |
|
110 | 110 | :type name: string |
|
111 | 111 | :returns: Populated config file |
|
112 | 112 | """ |
|
113 | 113 | if options is None: |
|
114 | 114 | options = {} |
|
115 | 115 | options.setdefault('version_table', 'migrate_version') |
|
116 | 116 | options.setdefault('repository_id', name) |
|
117 | 117 | options.setdefault('required_dbs', []) |
|
118 | 118 | |
|
119 | 119 | tmpl = open(os.path.join(tmpl_dir, cls._config)).read() |
|
120 | 120 | ret = TempitaTemplate(tmpl).substitute(options) |
|
121 | 121 | |
|
122 | 122 | # cleanup |
|
123 | 123 | del options['__template_name__'] |
|
124 | 124 | |
|
125 | 125 | return ret |
|
126 | 126 | |
|
127 | 127 | @classmethod |
|
128 | 128 | def create(cls, path, name, **opts): |
|
129 | 129 | """Create a repository at a specified path""" |
|
130 | 130 | cls.require_notfound(path) |
|
131 | 131 | theme = opts.pop('templates_theme', None) |
|
132 | 132 | t_path = opts.pop('templates_path', None) |
|
133 | 133 | |
|
134 | 134 | # Create repository |
|
135 | 135 | tmpl_dir = Template(t_path).get_repository(theme=theme) |
|
136 | 136 | shutil.copytree(tmpl_dir, path) |
|
137 | 137 | |
|
138 | 138 | # Edit config defaults |
|
139 | 139 | config_text = cls.prepare_config(tmpl_dir, name, options=opts) |
|
140 | 140 | fd = open(os.path.join(path, cls._config), 'w') |
|
141 | 141 | fd.write(config_text) |
|
142 | 142 | fd.close() |
|
143 | 143 | |
|
144 | 144 | opts['repository_name'] = name |
|
145 | 145 | |
|
146 | 146 | # Create a management script |
|
147 | 147 | manager = os.path.join(path, 'manage.py') |
|
148 | 148 | Repository.create_manage_file(manager, templates_theme=theme, |
|
149 | 149 | templates_path=t_path, **opts) |
|
150 | 150 | |
|
151 | 151 | return cls(path) |
|
152 | 152 | |
|
153 | 153 | def create_script(self, description, **k): |
|
154 | 154 | """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`""" |
|
155 | 155 | self.versions.create_new_python_version(description, **k) |
|
156 | 156 | |
|
157 | 157 | def create_script_sql(self, database, **k): |
|
158 | 158 | """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`""" |
|
159 | 159 | self.versions.create_new_sql_version(database, **k) |
|
160 | 160 | |
|
161 | 161 | @property |
|
162 | 162 | def latest(self): |
|
163 | 163 | """API to :attr:`migrate.versioning.version.Collection.latest`""" |
|
164 | 164 | return self.versions.latest |
|
165 | 165 | |
|
166 | 166 | @property |
|
167 | 167 | def version_table(self): |
|
168 | 168 | """Returns version_table name specified in config""" |
|
169 | 169 | return self.config.get('db_settings', 'version_table') |
|
170 | 170 | |
|
171 | 171 | @property |
|
172 | 172 | def id(self): |
|
173 | 173 | """Returns repository id specified in config""" |
|
174 | 174 | return self.config.get('db_settings', 'repository_id') |
|
175 | 175 | |
|
176 | 176 | def version(self, *p, **k): |
|
177 | 177 | """API to :attr:`migrate.versioning.version.Collection.version`""" |
|
178 | 178 | return self.versions.version(*p, **k) |
|
179 | 179 | |
|
180 | 180 | @classmethod |
|
181 | 181 | def clear(cls): |
|
182 | 182 | # TODO: deletes repo |
|
183 | 183 | super(Repository, cls).clear() |
|
184 | 184 | version.Collection.clear() |
|
185 | 185 | |
|
186 | 186 | def changeset(self, database, start, end=None): |
|
187 | 187 | """Create a changeset to migrate this database from ver. start to end/latest. |
|
188 | 188 | |
|
189 | 189 | :param database: name of database to generate changeset |
|
190 | 190 | :param start: version to start at |
|
191 | 191 | :param end: version to end at (latest if None given) |
|
192 | 192 | :type database: string |
|
193 | 193 | :type start: int |
|
194 | 194 | :type end: int |
|
195 | 195 | :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>` |
|
196 | 196 | """ |
|
197 | 197 | start = version.VerNum(start) |
|
198 | 198 | |
|
199 | 199 | if end is None: |
|
200 | 200 | end = self.latest |
|
201 | 201 | else: |
|
202 | 202 | end = version.VerNum(end) |
|
203 | 203 | |
|
204 | 204 | if start <= end: |
|
205 | 205 | step = 1 |
|
206 | 206 | range_mod = 1 |
|
207 | 207 | op = 'upgrade' |
|
208 | 208 | else: |
|
209 | 209 | step = -1 |
|
210 | 210 | range_mod = 0 |
|
211 | 211 | op = 'downgrade' |
|
212 | 212 | |
|
213 | 213 | versions = range(start + range_mod, end + range_mod, step) |
|
214 | 214 | changes = [self.version(v).script(database, op) for v in versions] |
|
215 | 215 | ret = Changeset(start, step=step, *changes) |
|
216 | 216 | return ret |
|
217 | 217 | |
|
218 | 218 | @classmethod |
|
219 | 219 | def create_manage_file(cls, file_, **opts): |
|
220 | 220 | """Create a project management script (manage.py) |
|
221 | 221 | |
|
222 | 222 | :param file_: Destination file to be written |
|
223 | 223 | :param opts: Options that are passed to :func:`migrate.versioning.shell.main` |
|
224 | 224 | """ |
|
225 | 225 | mng_file = Template(opts.pop('templates_path', None))\ |
|
226 | 226 | .get_manage(theme=opts.pop('templates_theme', None)) |
|
227 | 227 | |
|
228 | 228 | tmpl = open(mng_file).read() |
|
229 | 229 | fd = open(file_, 'w') |
|
230 | 230 | fd.write(TempitaTemplate(tmpl).substitute(opts)) |
|
231 | 231 | fd.close() |
@@ -1,213 +1,213 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Database schema version management. |
|
3 | 3 | """ |
|
4 | 4 | import sys |
|
5 | 5 | import logging |
|
6 | 6 | |
|
7 | 7 | from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, |
|
8 | 8 | create_engine) |
|
9 | 9 | from sqlalchemy.sql import and_ |
|
10 | 10 | from sqlalchemy import exceptions as sa_exceptions |
|
11 | 11 | from sqlalchemy.sql import bindparam |
|
12 | 12 | |
|
13 | from migrate import exceptions | |
|
14 | from migrate.versioning import genmodel, schemadiff | |
|
15 | from migrate.versioning.repository import Repository | |
|
16 | from migrate.versioning.util import load_model | |
|
17 | from migrate.versioning.version import VerNum | |
|
13 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
14 | from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff | |
|
15 | from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository | |
|
16 | from rhodecode.lib.dbmigrate.migrate.versioning.util import load_model | |
|
17 | from rhodecode.lib.dbmigrate.migrate.versioning.version import VerNum | |
|
18 | 18 | |
|
19 | 19 | |
|
20 | 20 | log = logging.getLogger(__name__) |
|
21 | 21 | |
|
22 | 22 | class ControlledSchema(object): |
|
23 | 23 | """A database under version control""" |
|
24 | 24 | |
|
25 | 25 | def __init__(self, engine, repository): |
|
26 | 26 | if isinstance(repository, basestring): |
|
27 | 27 | repository = Repository(repository) |
|
28 | 28 | self.engine = engine |
|
29 | 29 | self.repository = repository |
|
30 | 30 | self.meta = MetaData(engine) |
|
31 | 31 | self.load() |
|
32 | 32 | |
|
33 | 33 | def __eq__(self, other): |
|
34 | 34 | """Compare two schemas by repositories and versions""" |
|
35 | 35 | return (self.repository is other.repository \ |
|
36 | 36 | and self.version == other.version) |
|
37 | 37 | |
|
38 | 38 | def load(self): |
|
39 | 39 | """Load controlled schema version info from DB""" |
|
40 | 40 | tname = self.repository.version_table |
|
41 | 41 | try: |
|
42 | 42 | if not hasattr(self, 'table') or self.table is None: |
|
43 | 43 | self.table = Table(tname, self.meta, autoload=True) |
|
44 | 44 | |
|
45 | 45 | result = self.engine.execute(self.table.select( |
|
46 | 46 | self.table.c.repository_id == str(self.repository.id))) |
|
47 | 47 | |
|
48 | 48 | data = list(result)[0] |
|
49 | 49 | except: |
|
50 | 50 | cls, exc, tb = sys.exc_info() |
|
51 | 51 | raise exceptions.DatabaseNotControlledError, exc.__str__(), tb |
|
52 | 52 | |
|
53 | 53 | self.version = data['version'] |
|
54 | 54 | return data |
|
55 | 55 | |
|
56 | 56 | def drop(self): |
|
57 | 57 | """ |
|
58 | 58 | Remove version control from a database. |
|
59 | 59 | """ |
|
60 | 60 | try: |
|
61 | 61 | self.table.drop() |
|
62 | 62 | except (sa_exceptions.SQLError): |
|
63 | 63 | raise exceptions.DatabaseNotControlledError(str(self.table)) |
|
64 | 64 | |
|
65 | 65 | def changeset(self, version=None): |
|
66 | 66 | """API to Changeset creation. |
|
67 | 67 | |
|
68 | 68 | Uses self.version for start version and engine.name |
|
69 | 69 | to get database name. |
|
70 | 70 | """ |
|
71 | 71 | database = self.engine.name |
|
72 | 72 | start_ver = self.version |
|
73 | 73 | changeset = self.repository.changeset(database, start_ver, version) |
|
74 | 74 | return changeset |
|
75 | 75 | |
|
76 | 76 | def runchange(self, ver, change, step): |
|
77 | 77 | startver = ver |
|
78 | 78 | endver = ver + step |
|
79 | 79 | # Current database version must be correct! Don't run if corrupt! |
|
80 | 80 | if self.version != startver: |
|
81 | 81 | raise exceptions.InvalidVersionError("%s is not %s" % \ |
|
82 | 82 | (self.version, startver)) |
|
83 | 83 | # Run the change |
|
84 | 84 | change.run(self.engine, step) |
|
85 | 85 | |
|
86 | 86 | # Update/refresh database version |
|
87 | 87 | self.update_repository_table(startver, endver) |
|
88 | 88 | self.load() |
|
89 | 89 | |
|
90 | 90 | def update_repository_table(self, startver, endver): |
|
91 | 91 | """Update version_table with new information""" |
|
92 | 92 | update = self.table.update(and_(self.table.c.version == int(startver), |
|
93 | 93 | self.table.c.repository_id == str(self.repository.id))) |
|
94 | 94 | self.engine.execute(update, version=int(endver)) |
|
95 | 95 | |
|
96 | 96 | def upgrade(self, version=None): |
|
97 | 97 | """ |
|
98 | 98 | Upgrade (or downgrade) to a specified version, or latest version. |
|
99 | 99 | """ |
|
100 | 100 | changeset = self.changeset(version) |
|
101 | 101 | for ver, change in changeset: |
|
102 | 102 | self.runchange(ver, change, changeset.step) |
|
103 | 103 | |
|
104 | 104 | def update_db_from_model(self, model): |
|
105 | 105 | """ |
|
106 | 106 | Modify the database to match the structure of the current Python model. |
|
107 | 107 | """ |
|
108 | 108 | model = load_model(model) |
|
109 | 109 | |
|
110 | 110 | diff = schemadiff.getDiffOfModelAgainstDatabase( |
|
111 | 111 | model, self.engine, excludeTables=[self.repository.version_table] |
|
112 | 112 | ) |
|
113 | 113 | genmodel.ModelGenerator(diff,self.engine).applyModel() |
|
114 | 114 | |
|
115 | 115 | self.update_repository_table(self.version, int(self.repository.latest)) |
|
116 | 116 | |
|
117 | 117 | self.load() |
|
118 | 118 | |
|
119 | 119 | @classmethod |
|
120 | 120 | def create(cls, engine, repository, version=None): |
|
121 | 121 | """ |
|
122 | 122 | Declare a database to be under a repository's version control. |
|
123 | 123 | |
|
124 | 124 | :raises: :exc:`DatabaseAlreadyControlledError` |
|
125 | 125 | :returns: :class:`ControlledSchema` |
|
126 | 126 | """ |
|
127 | 127 | # Confirm that the version # is valid: positive, integer, |
|
128 | 128 | # exists in repos |
|
129 | 129 | if isinstance(repository, basestring): |
|
130 | 130 | repository = Repository(repository) |
|
131 | 131 | version = cls._validate_version(repository, version) |
|
132 | 132 | table = cls._create_table_version(engine, repository, version) |
|
133 | 133 | # TODO: history table |
|
134 | 134 | # Load repository information and return |
|
135 | 135 | return cls(engine, repository) |
|
136 | 136 | |
|
137 | 137 | @classmethod |
|
138 | 138 | def _validate_version(cls, repository, version): |
|
139 | 139 | """ |
|
140 | 140 | Ensures this is a valid version number for this repository. |
|
141 | 141 | |
|
142 | 142 | :raises: :exc:`InvalidVersionError` if invalid |
|
143 | 143 | :return: valid version number |
|
144 | 144 | """ |
|
145 | 145 | if version is None: |
|
146 | 146 | version = 0 |
|
147 | 147 | try: |
|
148 | 148 | version = VerNum(version) # raises valueerror |
|
149 | 149 | if version < 0 or version > repository.latest: |
|
150 | 150 | raise ValueError() |
|
151 | 151 | except ValueError: |
|
152 | 152 | raise exceptions.InvalidVersionError(version) |
|
153 | 153 | return version |
|
154 | 154 | |
|
155 | 155 | @classmethod |
|
156 | 156 | def _create_table_version(cls, engine, repository, version): |
|
157 | 157 | """ |
|
158 | 158 | Creates the versioning table in a database. |
|
159 | 159 | |
|
160 | 160 | :raises: :exc:`DatabaseAlreadyControlledError` |
|
161 | 161 | """ |
|
162 | 162 | # Create tables |
|
163 | 163 | tname = repository.version_table |
|
164 | 164 | meta = MetaData(engine) |
|
165 | 165 | |
|
166 | 166 | table = Table( |
|
167 | 167 | tname, meta, |
|
168 | 168 | Column('repository_id', String(250), primary_key=True), |
|
169 | 169 | Column('repository_path', Text), |
|
170 | 170 | Column('version', Integer), ) |
|
171 | 171 | |
|
172 | 172 | # there can be multiple repositories/schemas in the same db |
|
173 | 173 | if not table.exists(): |
|
174 | 174 | table.create() |
|
175 | 175 | |
|
176 | 176 | # test for existing repository_id |
|
177 | 177 | s = table.select(table.c.repository_id == bindparam("repository_id")) |
|
178 | 178 | result = engine.execute(s, repository_id=repository.id) |
|
179 | 179 | if result.fetchone(): |
|
180 | 180 | raise exceptions.DatabaseAlreadyControlledError |
|
181 | 181 | |
|
182 | 182 | # Insert data |
|
183 | 183 | engine.execute(table.insert().values( |
|
184 | 184 | repository_id=repository.id, |
|
185 | 185 | repository_path=repository.path, |
|
186 | 186 | version=int(version))) |
|
187 | 187 | return table |
|
188 | 188 | |
|
189 | 189 | @classmethod |
|
190 | 190 | def compare_model_to_db(cls, engine, model, repository): |
|
191 | 191 | """ |
|
192 | 192 | Compare the current model against the current database. |
|
193 | 193 | """ |
|
194 | 194 | if isinstance(repository, basestring): |
|
195 | 195 | repository = Repository(repository) |
|
196 | 196 | model = load_model(model) |
|
197 | 197 | |
|
198 | 198 | diff = schemadiff.getDiffOfModelAgainstDatabase( |
|
199 | 199 | model, engine, excludeTables=[repository.version_table]) |
|
200 | 200 | return diff |
|
201 | 201 | |
|
202 | 202 | @classmethod |
|
203 | 203 | def create_model(cls, engine, repository, declarative=False): |
|
204 | 204 | """ |
|
205 | 205 | Dump the current database as a Python model. |
|
206 | 206 | """ |
|
207 | 207 | if isinstance(repository, basestring): |
|
208 | 208 | repository = Repository(repository) |
|
209 | 209 | |
|
210 | 210 | diff = schemadiff.getDiffOfModelAgainstDatabase( |
|
211 | 211 | MetaData(), engine, excludeTables=[repository.version_table] |
|
212 | 212 | ) |
|
213 | 213 | return genmodel.ModelGenerator(diff, engine, declarative).toPython() |
@@ -1,285 +1,285 b'' | |||
|
1 | 1 | """ |
|
2 | 2 | Schema differencing support. |
|
3 | 3 | """ |
|
4 | 4 | |
|
5 | 5 | import logging |
|
6 | 6 | import sqlalchemy |
|
7 | 7 | |
|
8 | from migrate.changeset import SQLA_06 | |
|
8 | from rhodecode.lib.dbmigrate.migrate.changeset import SQLA_06 | |
|
9 | 9 | from sqlalchemy.types import Float |
|
10 | 10 | |
|
11 | 11 | log = logging.getLogger(__name__) |
|
12 | 12 | |
|
13 | 13 | def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None): |
|
14 | 14 | """ |
|
15 | 15 | Return differences of model against database. |
|
16 | 16 | |
|
17 | 17 | :return: object which will evaluate to :keyword:`True` if there \ |
|
18 | 18 | are differences else :keyword:`False`. |
|
19 | 19 | """ |
|
20 | 20 | return SchemaDiff(metadata, |
|
21 | 21 | sqlalchemy.MetaData(engine, reflect=True), |
|
22 | 22 | labelA='model', |
|
23 | 23 | labelB='database', |
|
24 | 24 | excludeTables=excludeTables) |
|
25 | 25 | |
|
26 | 26 | |
|
27 | 27 | def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None): |
|
28 | 28 | """ |
|
29 | 29 | Return differences of model against another model. |
|
30 | 30 | |
|
31 | 31 | :return: object which will evaluate to :keyword:`True` if there \ |
|
32 | 32 | are differences else :keyword:`False`. |
|
33 | 33 | """ |
|
34 | 34 | return SchemaDiff(metadataA, metadataB, excludeTables) |
|
35 | 35 | |
|
36 | 36 | |
|
37 | 37 | class ColDiff(object): |
|
38 | 38 | """ |
|
39 | 39 | Container for differences in one :class:`~sqlalchemy.schema.Column` |
|
40 | 40 | between two :class:`~sqlalchemy.schema.Table` instances, ``A`` |
|
41 | 41 | and ``B``. |
|
42 | 42 | |
|
43 | 43 | .. attribute:: col_A |
|
44 | 44 | |
|
45 | 45 | The :class:`~sqlalchemy.schema.Column` object for A. |
|
46 | 46 | |
|
47 | 47 | .. attribute:: col_B |
|
48 | 48 | |
|
49 | 49 | The :class:`~sqlalchemy.schema.Column` object for B. |
|
50 | 50 | |
|
51 | 51 | .. attribute:: type_A |
|
52 | 52 | |
|
53 | 53 | The most generic type of the :class:`~sqlalchemy.schema.Column` |
|
54 | 54 | object in A. |
|
55 | 55 | |
|
56 | 56 | .. attribute:: type_B |
|
57 | 57 | |
|
58 | 58 | The most generic type of the :class:`~sqlalchemy.schema.Column` |
|
59 | 59 | object in A. |
|
60 | 60 | |
|
61 | 61 | """ |
|
62 | 62 | |
|
63 | 63 | diff = False |
|
64 | 64 | |
|
65 | 65 | def __init__(self,col_A,col_B): |
|
66 | 66 | self.col_A = col_A |
|
67 | 67 | self.col_B = col_B |
|
68 | 68 | |
|
69 | 69 | self.type_A = col_A.type |
|
70 | 70 | self.type_B = col_B.type |
|
71 | 71 | |
|
72 | 72 | self.affinity_A = self.type_A._type_affinity |
|
73 | 73 | self.affinity_B = self.type_B._type_affinity |
|
74 | 74 | |
|
75 | 75 | if self.affinity_A is not self.affinity_B: |
|
76 | 76 | self.diff = True |
|
77 | 77 | return |
|
78 | 78 | |
|
79 | 79 | if isinstance(self.type_A,Float) or isinstance(self.type_B,Float): |
|
80 | 80 | if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)): |
|
81 | 81 | self.diff=True |
|
82 | 82 | return |
|
83 | 83 | |
|
84 | 84 | for attr in ('precision','scale','length'): |
|
85 | 85 | A = getattr(self.type_A,attr,None) |
|
86 | 86 | B = getattr(self.type_B,attr,None) |
|
87 | 87 | if not (A is None or B is None) and A!=B: |
|
88 | 88 | self.diff=True |
|
89 | 89 | return |
|
90 | 90 | |
|
91 | 91 | def __nonzero__(self): |
|
92 | 92 | return self.diff |
|
93 | 93 | |
|
94 | 94 | class TableDiff(object): |
|
95 | 95 | """ |
|
96 | 96 | Container for differences in one :class:`~sqlalchemy.schema.Table` |
|
97 | 97 | between two :class:`~sqlalchemy.schema.MetaData` instances, ``A`` |
|
98 | 98 | and ``B``. |
|
99 | 99 | |
|
100 | 100 | .. attribute:: columns_missing_from_A |
|
101 | 101 | |
|
102 | 102 | A sequence of column names that were found in B but weren't in |
|
103 | 103 | A. |
|
104 | 104 | |
|
105 | 105 | .. attribute:: columns_missing_from_B |
|
106 | 106 | |
|
107 | 107 | A sequence of column names that were found in A but weren't in |
|
108 | 108 | B. |
|
109 | 109 | |
|
110 | 110 | .. attribute:: columns_different |
|
111 | 111 | |
|
112 | 112 | A dictionary containing information about columns that were |
|
113 | 113 | found to be different. |
|
114 | 114 | It maps column names to a :class:`ColDiff` objects describing the |
|
115 | 115 | differences found. |
|
116 | 116 | """ |
|
117 | 117 | __slots__ = ( |
|
118 | 118 | 'columns_missing_from_A', |
|
119 | 119 | 'columns_missing_from_B', |
|
120 | 120 | 'columns_different', |
|
121 | 121 | ) |
|
122 | 122 | |
|
123 | 123 | def __nonzero__(self): |
|
124 | 124 | return bool( |
|
125 | 125 | self.columns_missing_from_A or |
|
126 | 126 | self.columns_missing_from_B or |
|
127 | 127 | self.columns_different |
|
128 | 128 | ) |
|
129 | 129 | |
|
130 | 130 | class SchemaDiff(object): |
|
131 | 131 | """ |
|
132 | 132 | Compute the difference between two :class:`~sqlalchemy.schema.MetaData` |
|
133 | 133 | objects. |
|
134 | 134 | |
|
135 | 135 | The string representation of a :class:`SchemaDiff` will summarise |
|
136 | 136 | the changes found between the two |
|
137 | 137 | :class:`~sqlalchemy.schema.MetaData` objects. |
|
138 | 138 | |
|
139 | 139 | The length of a :class:`SchemaDiff` will give the number of |
|
140 | 140 | changes found, enabling it to be used much like a boolean in |
|
141 | 141 | expressions. |
|
142 | 142 | |
|
143 | 143 | :param metadataA: |
|
144 | 144 | First :class:`~sqlalchemy.schema.MetaData` to compare. |
|
145 | 145 | |
|
146 | 146 | :param metadataB: |
|
147 | 147 | Second :class:`~sqlalchemy.schema.MetaData` to compare. |
|
148 | 148 | |
|
149 | 149 | :param labelA: |
|
150 | 150 | The label to use in messages about the first |
|
151 | 151 | :class:`~sqlalchemy.schema.MetaData`. |
|
152 | 152 | |
|
153 | 153 | :param labelB: |
|
154 | 154 | The label to use in messages about the second |
|
155 | 155 | :class:`~sqlalchemy.schema.MetaData`. |
|
156 | 156 | |
|
157 | 157 | :param excludeTables: |
|
158 | 158 | A sequence of table names to exclude. |
|
159 | 159 | |
|
160 | 160 | .. attribute:: tables_missing_from_A |
|
161 | 161 | |
|
162 | 162 | A sequence of table names that were found in B but weren't in |
|
163 | 163 | A. |
|
164 | 164 | |
|
165 | 165 | .. attribute:: tables_missing_from_B |
|
166 | 166 | |
|
167 | 167 | A sequence of table names that were found in A but weren't in |
|
168 | 168 | B. |
|
169 | 169 | |
|
170 | 170 | .. attribute:: tables_different |
|
171 | 171 | |
|
172 | 172 | A dictionary containing information about tables that were found |
|
173 | 173 | to be different. |
|
174 | 174 | It maps table names to a :class:`TableDiff` objects describing the |
|
175 | 175 | differences found. |
|
176 | 176 | """ |
|
177 | 177 | |
|
178 | 178 | def __init__(self, |
|
179 | 179 | metadataA, metadataB, |
|
180 | 180 | labelA='metadataA', |
|
181 | 181 | labelB='metadataB', |
|
182 | 182 | excludeTables=None): |
|
183 | 183 | |
|
184 | 184 | self.metadataA, self.metadataB = metadataA, metadataB |
|
185 | 185 | self.labelA, self.labelB = labelA, labelB |
|
186 | 186 | self.label_width = max(len(labelA),len(labelB)) |
|
187 | 187 | excludeTables = set(excludeTables or []) |
|
188 | 188 | |
|
189 | 189 | A_table_names = set(metadataA.tables.keys()) |
|
190 | 190 | B_table_names = set(metadataB.tables.keys()) |
|
191 | 191 | |
|
192 | 192 | self.tables_missing_from_A = sorted( |
|
193 | 193 | B_table_names - A_table_names - excludeTables |
|
194 | 194 | ) |
|
195 | 195 | self.tables_missing_from_B = sorted( |
|
196 | 196 | A_table_names - B_table_names - excludeTables |
|
197 | 197 | ) |
|
198 | 198 | |
|
199 | 199 | self.tables_different = {} |
|
200 | 200 | for table_name in A_table_names.intersection(B_table_names): |
|
201 | 201 | |
|
202 | 202 | td = TableDiff() |
|
203 | 203 | |
|
204 | 204 | A_table = metadataA.tables[table_name] |
|
205 | 205 | B_table = metadataB.tables[table_name] |
|
206 | 206 | |
|
207 | 207 | A_column_names = set(A_table.columns.keys()) |
|
208 | 208 | B_column_names = set(B_table.columns.keys()) |
|
209 | 209 | |
|
210 | 210 | td.columns_missing_from_A = sorted( |
|
211 | 211 | B_column_names - A_column_names |
|
212 | 212 | ) |
|
213 | 213 | |
|
214 | 214 | td.columns_missing_from_B = sorted( |
|
215 | 215 | A_column_names - B_column_names |
|
216 | 216 | ) |
|
217 | 217 | |
|
218 | 218 | td.columns_different = {} |
|
219 | 219 | |
|
220 | 220 | for col_name in A_column_names.intersection(B_column_names): |
|
221 | 221 | |
|
222 | 222 | cd = ColDiff( |
|
223 | 223 | A_table.columns.get(col_name), |
|
224 | 224 | B_table.columns.get(col_name) |
|
225 | 225 | ) |
|
226 | 226 | |
|
227 | 227 | if cd: |
|
228 | 228 | td.columns_different[col_name]=cd |
|
229 | 229 | |
|
230 | 230 | # XXX - index and constraint differences should |
|
231 | 231 | # be checked for here |
|
232 | 232 | |
|
233 | 233 | if td: |
|
234 | 234 | self.tables_different[table_name]=td |
|
235 | 235 | |
|
236 | 236 | def __str__(self): |
|
237 | 237 | ''' Summarize differences. ''' |
|
238 | 238 | out = [] |
|
239 | 239 | column_template =' %%%is: %%r' % self.label_width |
|
240 | 240 | |
|
241 | 241 | for names,label in ( |
|
242 | 242 | (self.tables_missing_from_A,self.labelA), |
|
243 | 243 | (self.tables_missing_from_B,self.labelB), |
|
244 | 244 | ): |
|
245 | 245 | if names: |
|
246 | 246 | out.append( |
|
247 | 247 | ' tables missing from %s: %s' % ( |
|
248 | 248 | label,', '.join(sorted(names)) |
|
249 | 249 | ) |
|
250 | 250 | ) |
|
251 | 251 | |
|
252 | 252 | for name,td in sorted(self.tables_different.items()): |
|
253 | 253 | out.append( |
|
254 | 254 | ' table with differences: %s' % name |
|
255 | 255 | ) |
|
256 | 256 | for names,label in ( |
|
257 | 257 | (td.columns_missing_from_A,self.labelA), |
|
258 | 258 | (td.columns_missing_from_B,self.labelB), |
|
259 | 259 | ): |
|
260 | 260 | if names: |
|
261 | 261 | out.append( |
|
262 | 262 | ' %s missing these columns: %s' % ( |
|
263 | 263 | label,', '.join(sorted(names)) |
|
264 | 264 | ) |
|
265 | 265 | ) |
|
266 | 266 | for name,cd in td.columns_different.items(): |
|
267 | 267 | out.append(' column with differences: %s' % name) |
|
268 | 268 | out.append(column_template % (self.labelA,cd.col_A)) |
|
269 | 269 | out.append(column_template % (self.labelB,cd.col_B)) |
|
270 | 270 | |
|
271 | 271 | if out: |
|
272 | 272 | out.insert(0, 'Schema diffs:') |
|
273 | 273 | return '\n'.join(out) |
|
274 | 274 | else: |
|
275 | 275 | return 'No schema diffs' |
|
276 | 276 | |
|
277 | 277 | def __len__(self): |
|
278 | 278 | """ |
|
279 | 279 | Used in bool evaluation, return of 0 means no diffs. |
|
280 | 280 | """ |
|
281 | 281 | return ( |
|
282 | 282 | len(self.tables_missing_from_A) + |
|
283 | 283 | len(self.tables_missing_from_B) + |
|
284 | 284 | len(self.tables_different) |
|
285 | 285 | ) |
@@ -1,6 +1,6 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | |
|
4 | from migrate.versioning.script.base import BaseScript | |
|
5 | from migrate.versioning.script.py import PythonScript | |
|
6 | from migrate.versioning.script.sql import SqlScript | |
|
4 | from rhodecode.lib.dbmigrate.migrate.versioning.script.base import BaseScript | |
|
5 | from rhodecode.lib.dbmigrate.migrate.versioning.script.py import PythonScript | |
|
6 | from rhodecode.lib.dbmigrate.migrate.versioning.script.sql import SqlScript |
@@ -1,57 +1,57 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | import logging |
|
4 | 4 | |
|
5 | from migrate import exceptions | |
|
6 | from migrate.versioning.config import operations | |
|
7 | from migrate.versioning import pathed | |
|
5 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
6 | from rhodecode.lib.dbmigrate.migrate.versioning.config import operations | |
|
7 | from rhodecode.lib.dbmigrate.migrate.versioning import pathed | |
|
8 | 8 | |
|
9 | 9 | |
|
10 | 10 | log = logging.getLogger(__name__) |
|
11 | 11 | |
|
12 | 12 | class BaseScript(pathed.Pathed): |
|
13 | 13 | """Base class for other types of scripts. |
|
14 | 14 | All scripts have the following properties: |
|
15 | 15 | |
|
16 | 16 | source (script.source()) |
|
17 | 17 | The source code of the script |
|
18 | 18 | version (script.version()) |
|
19 | 19 | The version number of the script |
|
20 | 20 | operations (script.operations()) |
|
21 | 21 | The operations defined by the script: upgrade(), downgrade() or both. |
|
22 | 22 | Returns a tuple of operations. |
|
23 | 23 | Can also check for an operation with ex. script.operation(Script.ops.up) |
|
24 | 24 | """ # TODO: sphinxfy this and implement it correctly |
|
25 | 25 | |
|
26 | 26 | def __init__(self, path): |
|
27 | 27 | log.debug('Loading script %s...' % path) |
|
28 | 28 | self.verify(path) |
|
29 | 29 | super(BaseScript, self).__init__(path) |
|
30 | 30 | log.debug('Script %s loaded successfully' % path) |
|
31 | 31 | |
|
32 | 32 | @classmethod |
|
33 | 33 | def verify(cls, path): |
|
34 | 34 | """Ensure this is a valid script |
|
35 | 35 | This version simply ensures the script file's existence |
|
36 | 36 | |
|
37 | 37 | :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>` |
|
38 | 38 | """ |
|
39 | 39 | try: |
|
40 | 40 | cls.require_found(path) |
|
41 | 41 | except: |
|
42 | 42 | raise exceptions.InvalidScriptError(path) |
|
43 | 43 | |
|
44 | 44 | def source(self): |
|
45 | 45 | """:returns: source code of the script. |
|
46 | 46 | :rtype: string |
|
47 | 47 | """ |
|
48 | 48 | fd = open(self.path) |
|
49 | 49 | ret = fd.read() |
|
50 | 50 | fd.close() |
|
51 | 51 | return ret |
|
52 | 52 | |
|
53 | 53 | def run(self, engine): |
|
54 | 54 | """Core of each BaseScript subclass. |
|
55 | 55 | This method executes the script. |
|
56 | 56 | """ |
|
57 | 57 | raise NotImplementedError() |
@@ -1,159 +1,159 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | |
|
4 | 4 | import shutil |
|
5 | 5 | import warnings |
|
6 | 6 | import logging |
|
7 | 7 | from StringIO import StringIO |
|
8 | 8 | |
|
9 | import migrate | |
|
10 | from migrate.versioning import genmodel, schemadiff | |
|
11 | from migrate.versioning.config import operations | |
|
12 | from migrate.versioning.template import Template | |
|
13 | from migrate.versioning.script import base | |
|
14 | from migrate.versioning.util import import_path, load_model, with_engine | |
|
15 | from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError | |
|
9 | from rhodecode.lib.dbmigrate import migrate | |
|
10 | from rhodecode.lib.dbmigrate.migrate.versioning import genmodel, schemadiff | |
|
11 | from rhodecode.lib.dbmigrate.migrate.versioning.config import operations | |
|
12 | from rhodecode.lib.dbmigrate.migrate.versioning.template import Template | |
|
13 | from rhodecode.lib.dbmigrate.migrate.versioning.script import base | |
|
14 | from rhodecode.lib.dbmigrate.migrate.versioning.util import import_path, load_model, with_engine | |
|
15 | from rhodecode.lib.dbmigrate.migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError | |
|
16 | 16 | |
|
17 | 17 | log = logging.getLogger(__name__) |
|
18 | 18 | __all__ = ['PythonScript'] |
|
19 | 19 | |
|
20 | 20 | |
|
21 | 21 | class PythonScript(base.BaseScript): |
|
22 | 22 | """Base for Python scripts""" |
|
23 | 23 | |
|
24 | 24 | @classmethod |
|
25 | 25 | def create(cls, path, **opts): |
|
26 | 26 | """Create an empty migration script at specified path |
|
27 | 27 | |
|
28 | 28 | :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`""" |
|
29 | 29 | cls.require_notfound(path) |
|
30 | 30 | |
|
31 | 31 | src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None)) |
|
32 | 32 | shutil.copy(src, path) |
|
33 | 33 | |
|
34 | 34 | return cls(path) |
|
35 | 35 | |
|
36 | 36 | @classmethod |
|
37 | 37 | def make_update_script_for_model(cls, engine, oldmodel, |
|
38 | 38 | model, repository, **opts): |
|
39 | 39 | """Create a migration script based on difference between two SA models. |
|
40 | 40 | |
|
41 | 41 | :param repository: path to migrate repository |
|
42 | 42 | :param oldmodel: dotted.module.name:SAClass or SAClass object |
|
43 | 43 | :param model: dotted.module.name:SAClass or SAClass object |
|
44 | 44 | :param engine: SQLAlchemy engine |
|
45 | 45 | :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>` |
|
46 | 46 | :type oldmodel: string or Class |
|
47 | 47 | :type model: string or Class |
|
48 | 48 | :type engine: Engine instance |
|
49 | 49 | :returns: Upgrade / Downgrade script |
|
50 | 50 | :rtype: string |
|
51 | 51 | """ |
|
52 | ||
|
52 | ||
|
53 | 53 | if isinstance(repository, basestring): |
|
54 | 54 | # oh dear, an import cycle! |
|
55 | from migrate.versioning.repository import Repository | |
|
55 | from rhodecode.lib.dbmigrate.migrate.versioning.repository import Repository | |
|
56 | 56 | repository = Repository(repository) |
|
57 | 57 | |
|
58 | 58 | oldmodel = load_model(oldmodel) |
|
59 | 59 | model = load_model(model) |
|
60 | 60 | |
|
61 | 61 | # Compute differences. |
|
62 | 62 | diff = schemadiff.getDiffOfModelAgainstModel( |
|
63 | 63 | oldmodel, |
|
64 | 64 | model, |
|
65 | 65 | excludeTables=[repository.version_table]) |
|
66 | 66 | # TODO: diff can be False (there is no difference?) |
|
67 | 67 | decls, upgradeCommands, downgradeCommands = \ |
|
68 | genmodel.ModelGenerator(diff,engine).toUpgradeDowngradePython() | |
|
68 | genmodel.ModelGenerator(diff, engine).toUpgradeDowngradePython() | |
|
69 | 69 | |
|
70 | 70 | # Store differences into file. |
|
71 | 71 | src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None)) |
|
72 | 72 | f = open(src) |
|
73 | 73 | contents = f.read() |
|
74 | 74 | f.close() |
|
75 | 75 | |
|
76 | 76 | # generate source |
|
77 | 77 | search = 'def upgrade(migrate_engine):' |
|
78 | 78 | contents = contents.replace(search, '\n\n'.join((decls, search)), 1) |
|
79 | 79 | if upgradeCommands: |
|
80 | 80 | contents = contents.replace(' pass', upgradeCommands, 1) |
|
81 | 81 | if downgradeCommands: |
|
82 | 82 | contents = contents.replace(' pass', downgradeCommands, 1) |
|
83 | 83 | return contents |
|
84 | 84 | |
|
85 | 85 | @classmethod |
|
86 | 86 | def verify_module(cls, path): |
|
87 | 87 | """Ensure path is a valid script |
|
88 | 88 | |
|
89 | 89 | :param path: Script location |
|
90 | 90 | :type path: string |
|
91 | 91 | :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>` |
|
92 | 92 | :returns: Python module |
|
93 | 93 | """ |
|
94 | 94 | # Try to import and get the upgrade() func |
|
95 | 95 | module = import_path(path) |
|
96 | 96 | try: |
|
97 | 97 | assert callable(module.upgrade) |
|
98 | 98 | except Exception, e: |
|
99 | 99 | raise InvalidScriptError(path + ': %s' % str(e)) |
|
100 | 100 | return module |
|
101 | 101 | |
|
102 | 102 | def preview_sql(self, url, step, **args): |
|
103 | 103 | """Mocks SQLAlchemy Engine to store all executed calls in a string |
|
104 | 104 | and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>` |
|
105 | 105 | |
|
106 | 106 | :returns: SQL file |
|
107 | 107 | """ |
|
108 | 108 | buf = StringIO() |
|
109 | 109 | args['engine_arg_strategy'] = 'mock' |
|
110 | 110 | args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p) |
|
111 | 111 | |
|
112 | 112 | @with_engine |
|
113 | 113 | def go(url, step, **kw): |
|
114 | 114 | engine = kw.pop('engine') |
|
115 | 115 | self.run(engine, step) |
|
116 | 116 | return buf.getvalue() |
|
117 | 117 | |
|
118 | 118 | return go(url, step, **args) |
|
119 | 119 | |
|
120 | 120 | def run(self, engine, step): |
|
121 | 121 | """Core method of Script file. |
|
122 | 122 | Exectues :func:`update` or :func:`downgrade` functions |
|
123 | 123 | |
|
124 | 124 | :param engine: SQLAlchemy Engine |
|
125 | 125 | :param step: Operation to run |
|
126 | 126 | :type engine: string |
|
127 | 127 | :type step: int |
|
128 | 128 | """ |
|
129 | 129 | if step > 0: |
|
130 | 130 | op = 'upgrade' |
|
131 | 131 | elif step < 0: |
|
132 | 132 | op = 'downgrade' |
|
133 | 133 | else: |
|
134 | 134 | raise ScriptError("%d is not a valid step" % step) |
|
135 | 135 | |
|
136 | 136 | funcname = base.operations[op] |
|
137 | 137 | script_func = self._func(funcname) |
|
138 | 138 | |
|
139 | 139 | try: |
|
140 | 140 | script_func(engine) |
|
141 | 141 | except TypeError: |
|
142 | 142 | warnings.warn("upgrade/downgrade functions must accept engine" |
|
143 | 143 | " parameter (since version > 0.5.4)", MigrateDeprecationWarning) |
|
144 | 144 | raise |
|
145 | 145 | |
|
146 | 146 | @property |
|
147 | 147 | def module(self): |
|
148 | 148 | """Calls :meth:`migrate.versioning.script.py.verify_module` |
|
149 | 149 | and returns it. |
|
150 | 150 | """ |
|
151 | 151 | if not hasattr(self, '_module'): |
|
152 | 152 | self._module = self.verify_module(self.path) |
|
153 | 153 | return self._module |
|
154 | 154 | |
|
155 | 155 | def _func(self, funcname): |
|
156 | 156 | if not hasattr(self.module, funcname): |
|
157 | 157 | msg = "Function '%s' is not defined in this script" |
|
158 | 158 | raise ScriptError(msg % funcname) |
|
159 | 159 | return getattr(self.module, funcname) |
@@ -1,49 +1,49 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | import logging |
|
4 | 4 | import shutil |
|
5 | 5 | |
|
6 | from migrate.versioning.script import base | |
|
7 | from migrate.versioning.template import Template | |
|
6 | from rhodecode.lib.dbmigrate.migrate.versioning.script import base | |
|
7 | from rhodecode.lib.dbmigrate.migrate.versioning.template import Template | |
|
8 | 8 | |
|
9 | 9 | |
|
10 | 10 | log = logging.getLogger(__name__) |
|
11 | 11 | |
|
12 | 12 | class SqlScript(base.BaseScript): |
|
13 | 13 | """A file containing plain SQL statements.""" |
|
14 | 14 | |
|
15 | 15 | @classmethod |
|
16 | 16 | def create(cls, path, **opts): |
|
17 | 17 | """Create an empty migration script at specified path |
|
18 | 18 | |
|
19 | 19 | :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`""" |
|
20 | 20 | cls.require_notfound(path) |
|
21 | 21 | |
|
22 | 22 | src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None)) |
|
23 | 23 | shutil.copy(src, path) |
|
24 | 24 | return cls(path) |
|
25 | 25 | |
|
26 | 26 | # TODO: why is step parameter even here? |
|
27 | 27 | def run(self, engine, step=None, executemany=True): |
|
28 | 28 | """Runs SQL script through raw dbapi execute call""" |
|
29 | 29 | text = self.source() |
|
30 | 30 | # Don't rely on SA's autocommit here |
|
31 | 31 | # (SA uses .startswith to check if a commit is needed. What if script |
|
32 | 32 | # starts with a comment?) |
|
33 | 33 | conn = engine.connect() |
|
34 | 34 | try: |
|
35 | 35 | trans = conn.begin() |
|
36 | 36 | try: |
|
37 | 37 | # HACK: SQLite doesn't allow multiple statements through |
|
38 | 38 | # its execute() method, but it provides executescript() instead |
|
39 | 39 | dbapi = conn.engine.raw_connection() |
|
40 | 40 | if executemany and getattr(dbapi, 'executescript', None): |
|
41 | 41 | dbapi.executescript(text) |
|
42 | 42 | else: |
|
43 | 43 | conn.execute(text) |
|
44 | 44 | trans.commit() |
|
45 | 45 | except: |
|
46 | 46 | trans.rollback() |
|
47 | 47 | raise |
|
48 | 48 | finally: |
|
49 | 49 | conn.close() |
@@ -1,215 +1,215 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | |
|
4 | 4 | """The migrate command-line tool.""" |
|
5 | 5 | |
|
6 | 6 | import sys |
|
7 | 7 | import inspect |
|
8 | 8 | import logging |
|
9 | 9 | from optparse import OptionParser, BadOptionError |
|
10 | 10 | |
|
11 | from migrate import exceptions | |
|
12 | from migrate.versioning import api | |
|
13 | from migrate.versioning.config import * | |
|
14 | from migrate.versioning.util import asbool | |
|
11 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
12 | from rhodecode.lib.dbmigrate.migrate.versioning import api | |
|
13 | from rhodecode.lib.dbmigrate.migrate.versioning.config import * | |
|
14 | from rhodecode.lib.dbmigrate.migrate.versioning.util import asbool | |
|
15 | 15 | |
|
16 | 16 | |
|
17 | 17 | alias = dict( |
|
18 | 18 | s=api.script, |
|
19 | 19 | vc=api.version_control, |
|
20 | 20 | dbv=api.db_version, |
|
21 | 21 | v=api.version, |
|
22 | 22 | ) |
|
23 | 23 | |
|
24 | 24 | def alias_setup(): |
|
25 | 25 | global alias |
|
26 | 26 | for key, val in alias.iteritems(): |
|
27 | 27 | setattr(api, key, val) |
|
28 | 28 | alias_setup() |
|
29 | 29 | |
|
30 | 30 | |
|
31 | 31 | class PassiveOptionParser(OptionParser): |
|
32 | 32 | |
|
33 | 33 | def _process_args(self, largs, rargs, values): |
|
34 | 34 | """little hack to support all --some_option=value parameters""" |
|
35 | 35 | |
|
36 | 36 | while rargs: |
|
37 | 37 | arg = rargs[0] |
|
38 | 38 | if arg == "--": |
|
39 | 39 | del rargs[0] |
|
40 | 40 | return |
|
41 | 41 | elif arg[0:2] == "--": |
|
42 | 42 | # if parser does not know about the option |
|
43 | 43 | # pass it along (make it anonymous) |
|
44 | 44 | try: |
|
45 | 45 | opt = arg.split('=', 1)[0] |
|
46 | 46 | self._match_long_opt(opt) |
|
47 | 47 | except BadOptionError: |
|
48 | 48 | largs.append(arg) |
|
49 | 49 | del rargs[0] |
|
50 | 50 | else: |
|
51 | 51 | self._process_long_opt(rargs, values) |
|
52 | 52 | elif arg[:1] == "-" and len(arg) > 1: |
|
53 | 53 | self._process_short_opts(rargs, values) |
|
54 | 54 | elif self.allow_interspersed_args: |
|
55 | 55 | largs.append(arg) |
|
56 | 56 | del rargs[0] |
|
57 | 57 | |
|
58 | 58 | def main(argv=None, **kwargs): |
|
59 | 59 | """Shell interface to :mod:`migrate.versioning.api`. |
|
60 | 60 | |
|
61 | 61 | kwargs are default options that can be overriden with passing |
|
62 | 62 | --some_option as command line option |
|
63 | 63 | |
|
64 | 64 | :param disable_logging: Let migrate configure logging |
|
65 | 65 | :type disable_logging: bool |
|
66 | 66 | """ |
|
67 | 67 | if argv is not None: |
|
68 | 68 | argv = argv |
|
69 | 69 | else: |
|
70 | 70 | argv = list(sys.argv[1:]) |
|
71 | 71 | commands = list(api.__all__) |
|
72 | 72 | commands.sort() |
|
73 | 73 | |
|
74 | 74 | usage = """%%prog COMMAND ... |
|
75 | 75 | |
|
76 | 76 | Available commands: |
|
77 | 77 | %s |
|
78 | 78 | |
|
79 | 79 | Enter "%%prog help COMMAND" for information on a particular command. |
|
80 | 80 | """ % '\n\t'.join(["%s - %s" % (command.ljust(28), |
|
81 | 81 | api.command_desc.get(command)) for command in commands]) |
|
82 | 82 | |
|
83 | 83 | parser = PassiveOptionParser(usage=usage) |
|
84 | 84 | parser.add_option("-d", "--debug", |
|
85 | 85 | action="store_true", |
|
86 | 86 | dest="debug", |
|
87 | 87 | default=False, |
|
88 | 88 | help="Shortcut to turn on DEBUG mode for logging") |
|
89 | 89 | parser.add_option("-q", "--disable_logging", |
|
90 | 90 | action="store_true", |
|
91 | 91 | dest="disable_logging", |
|
92 | 92 | default=False, |
|
93 | 93 | help="Use this option to disable logging configuration") |
|
94 | 94 | help_commands = ['help', '-h', '--help'] |
|
95 | 95 | HELP = False |
|
96 | 96 | |
|
97 | 97 | try: |
|
98 | 98 | command = argv.pop(0) |
|
99 | 99 | if command in help_commands: |
|
100 | 100 | HELP = True |
|
101 | 101 | command = argv.pop(0) |
|
102 | 102 | except IndexError: |
|
103 | 103 | parser.print_help() |
|
104 | 104 | return |
|
105 | 105 | |
|
106 | 106 | command_func = getattr(api, command, None) |
|
107 | 107 | if command_func is None or command.startswith('_'): |
|
108 | 108 | parser.error("Invalid command %s" % command) |
|
109 | 109 | |
|
110 | 110 | parser.set_usage(inspect.getdoc(command_func)) |
|
111 | 111 | f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func) |
|
112 | 112 | for arg in f_args: |
|
113 | 113 | parser.add_option( |
|
114 | 114 | "--%s" % arg, |
|
115 | 115 | dest=arg, |
|
116 | 116 | action='store', |
|
117 | 117 | type="string") |
|
118 | 118 | |
|
119 | 119 | # display help of the current command |
|
120 | 120 | if HELP: |
|
121 | 121 | parser.print_help() |
|
122 | 122 | return |
|
123 | 123 | |
|
124 | 124 | options, args = parser.parse_args(argv) |
|
125 | 125 | |
|
126 | 126 | # override kwargs with anonymous parameters |
|
127 | 127 | override_kwargs = dict() |
|
128 | 128 | for arg in list(args): |
|
129 | 129 | if arg.startswith('--'): |
|
130 | 130 | args.remove(arg) |
|
131 | 131 | if '=' in arg: |
|
132 | 132 | opt, value = arg[2:].split('=', 1) |
|
133 | 133 | else: |
|
134 | 134 | opt = arg[2:] |
|
135 | 135 | value = True |
|
136 | 136 | override_kwargs[opt] = value |
|
137 | 137 | |
|
138 | 138 | # override kwargs with options if user is overwriting |
|
139 | 139 | for key, value in options.__dict__.iteritems(): |
|
140 | 140 | if value is not None: |
|
141 | 141 | override_kwargs[key] = value |
|
142 | 142 | |
|
143 | 143 | # arguments that function accepts without passed kwargs |
|
144 | 144 | f_required = list(f_args) |
|
145 | 145 | candidates = dict(kwargs) |
|
146 | 146 | candidates.update(override_kwargs) |
|
147 | 147 | for key, value in candidates.iteritems(): |
|
148 | 148 | if key in f_args: |
|
149 | 149 | f_required.remove(key) |
|
150 | 150 | |
|
151 | 151 | # map function arguments to parsed arguments |
|
152 | 152 | for arg in args: |
|
153 | 153 | try: |
|
154 | 154 | kw = f_required.pop(0) |
|
155 | 155 | except IndexError: |
|
156 | 156 | parser.error("Too many arguments for command %s: %s" % (command, |
|
157 | 157 | arg)) |
|
158 | 158 | kwargs[kw] = arg |
|
159 | 159 | |
|
160 | 160 | # apply overrides |
|
161 | 161 | kwargs.update(override_kwargs) |
|
162 | 162 | |
|
163 | 163 | # configure options |
|
164 | 164 | for key, value in options.__dict__.iteritems(): |
|
165 | 165 | kwargs.setdefault(key, value) |
|
166 | 166 | |
|
167 | 167 | # configure logging |
|
168 | 168 | if not asbool(kwargs.pop('disable_logging', False)): |
|
169 | 169 | # filter to log =< INFO into stdout and rest to stderr |
|
170 | 170 | class SingleLevelFilter(logging.Filter): |
|
171 | 171 | def __init__(self, min=None, max=None): |
|
172 | 172 | self.min = min or 0 |
|
173 | 173 | self.max = max or 100 |
|
174 | 174 | |
|
175 | 175 | def filter(self, record): |
|
176 | 176 | return self.min <= record.levelno <= self.max |
|
177 | 177 | |
|
178 | 178 | logger = logging.getLogger() |
|
179 | 179 | h1 = logging.StreamHandler(sys.stdout) |
|
180 | 180 | f1 = SingleLevelFilter(max=logging.INFO) |
|
181 | 181 | h1.addFilter(f1) |
|
182 | 182 | h2 = logging.StreamHandler(sys.stderr) |
|
183 | 183 | f2 = SingleLevelFilter(min=logging.WARN) |
|
184 | 184 | h2.addFilter(f2) |
|
185 | 185 | logger.addHandler(h1) |
|
186 | 186 | logger.addHandler(h2) |
|
187 | 187 | |
|
188 | 188 | if options.debug: |
|
189 | 189 | logger.setLevel(logging.DEBUG) |
|
190 | 190 | else: |
|
191 | 191 | logger.setLevel(logging.INFO) |
|
192 | 192 | |
|
193 | 193 | log = logging.getLogger(__name__) |
|
194 | 194 | |
|
195 | 195 | # check if all args are given |
|
196 | 196 | try: |
|
197 | 197 | num_defaults = len(f_defaults) |
|
198 | 198 | except TypeError: |
|
199 | 199 | num_defaults = 0 |
|
200 | 200 | f_args_default = f_args[len(f_args) - num_defaults:] |
|
201 | 201 | required = list(set(f_required) - set(f_args_default)) |
|
202 | 202 | if required: |
|
203 | 203 | parser.error("Not enough arguments for command %s: %s not specified" \ |
|
204 | 204 | % (command, ', '.join(required))) |
|
205 | 205 | |
|
206 | 206 | # handle command |
|
207 | 207 | try: |
|
208 | 208 | ret = command_func(**kwargs) |
|
209 | 209 | if ret is not None: |
|
210 | 210 | log.info(ret) |
|
211 | 211 | except (exceptions.UsageError, exceptions.KnownError), e: |
|
212 | 212 | parser.error(e.args[0]) |
|
213 | 213 | |
|
214 | 214 | if __name__ == "__main__": |
|
215 | 215 | main() |
@@ -1,94 +1,94 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | |
|
4 | 4 | import os |
|
5 | 5 | import shutil |
|
6 | 6 | import sys |
|
7 | 7 | |
|
8 | 8 | from pkg_resources import resource_filename |
|
9 | 9 | |
|
10 | from migrate.versioning.config import * | |
|
11 | from migrate.versioning import pathed | |
|
10 | from rhodecode.lib.dbmigrate.migrate.versioning.config import * | |
|
11 | from rhodecode.lib.dbmigrate.migrate.versioning import pathed | |
|
12 | 12 | |
|
13 | 13 | |
|
14 | 14 | class Collection(pathed.Pathed): |
|
15 | 15 | """A collection of templates of a specific type""" |
|
16 | 16 | _mask = None |
|
17 | 17 | |
|
18 | 18 | def get_path(self, file): |
|
19 | 19 | return os.path.join(self.path, str(file)) |
|
20 | 20 | |
|
21 | 21 | |
|
22 | 22 | class RepositoryCollection(Collection): |
|
23 | 23 | _mask = '%s' |
|
24 | 24 | |
|
25 | 25 | class ScriptCollection(Collection): |
|
26 | 26 | _mask = '%s.py_tmpl' |
|
27 | 27 | |
|
28 | 28 | class ManageCollection(Collection): |
|
29 | 29 | _mask = '%s.py_tmpl' |
|
30 | 30 | |
|
31 | 31 | class SQLScriptCollection(Collection): |
|
32 | 32 | _mask = '%s.py_tmpl' |
|
33 | 33 | |
|
34 | 34 | class Template(pathed.Pathed): |
|
35 | 35 | """Finds the paths/packages of various Migrate templates. |
|
36 | 36 | |
|
37 | :param path: Templates are loaded from migrate package | |
|
37 | :param path: Templates are loaded from rhodecode.lib.dbmigrate.migrate package | |
|
38 | 38 | if `path` is not provided. |
|
39 | 39 | """ |
|
40 | 40 | pkg = 'migrate.versioning.templates' |
|
41 | 41 | _manage = 'manage.py_tmpl' |
|
42 | 42 | |
|
43 | 43 | def __new__(cls, path=None): |
|
44 | 44 | if path is None: |
|
45 | 45 | path = cls._find_path(cls.pkg) |
|
46 | 46 | return super(Template, cls).__new__(cls, path) |
|
47 | 47 | |
|
48 | 48 | def __init__(self, path=None): |
|
49 | 49 | if path is None: |
|
50 | 50 | path = Template._find_path(self.pkg) |
|
51 | 51 | super(Template, self).__init__(path) |
|
52 | 52 | self.repository = RepositoryCollection(os.path.join(path, 'repository')) |
|
53 | 53 | self.script = ScriptCollection(os.path.join(path, 'script')) |
|
54 | 54 | self.manage = ManageCollection(os.path.join(path, 'manage')) |
|
55 | 55 | self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script')) |
|
56 | 56 | |
|
57 | 57 | @classmethod |
|
58 | 58 | def _find_path(cls, pkg): |
|
59 | 59 | """Returns absolute path to dotted python package.""" |
|
60 | 60 | tmp_pkg = pkg.rsplit('.', 1) |
|
61 | 61 | |
|
62 | 62 | if len(tmp_pkg) != 1: |
|
63 | 63 | return resource_filename(tmp_pkg[0], tmp_pkg[1]) |
|
64 | 64 | else: |
|
65 | 65 | return resource_filename(tmp_pkg[0], '') |
|
66 | 66 | |
|
67 | 67 | def _get_item(self, collection, theme=None): |
|
68 | 68 | """Locates and returns collection. |
|
69 | 69 | |
|
70 | 70 | :param collection: name of collection to locate |
|
71 | 71 | :param type_: type of subfolder in collection (defaults to "_default") |
|
72 | 72 | :returns: (package, source) |
|
73 | 73 | :rtype: str, str |
|
74 | 74 | """ |
|
75 | 75 | item = getattr(self, collection) |
|
76 | 76 | theme_mask = getattr(item, '_mask') |
|
77 | 77 | theme = theme_mask % (theme or 'default') |
|
78 | 78 | return item.get_path(theme) |
|
79 | 79 | |
|
80 | 80 | def get_repository(self, *a, **kw): |
|
81 | 81 | """Calls self._get_item('repository', *a, **kw)""" |
|
82 | 82 | return self._get_item('repository', *a, **kw) |
|
83 | 83 | |
|
84 | 84 | def get_script(self, *a, **kw): |
|
85 | 85 | """Calls self._get_item('script', *a, **kw)""" |
|
86 | 86 | return self._get_item('script', *a, **kw) |
|
87 | 87 | |
|
88 | 88 | def get_sql_script(self, *a, **kw): |
|
89 | 89 | """Calls self._get_item('sql_script', *a, **kw)""" |
|
90 | 90 | return self._get_item('sql_script', *a, **kw) |
|
91 | 91 | |
|
92 | 92 | def get_manage(self, *a, **kw): |
|
93 | 93 | """Calls self._get_item('manage', *a, **kw)""" |
|
94 | 94 | return self._get_item('manage', *a, **kw) |
@@ -1,179 +1,179 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | """.. currentmodule:: migrate.versioning.util""" |
|
4 | 4 | |
|
5 | 5 | import warnings |
|
6 | 6 | import logging |
|
7 | 7 | from decorator import decorator |
|
8 | 8 | from pkg_resources import EntryPoint |
|
9 | 9 | |
|
10 | 10 | from sqlalchemy import create_engine |
|
11 | 11 | from sqlalchemy.engine import Engine |
|
12 | 12 | from sqlalchemy.pool import StaticPool |
|
13 | 13 | |
|
14 | from migrate import exceptions | |
|
15 | from migrate.versioning.util.keyedinstance import KeyedInstance | |
|
16 | from migrate.versioning.util.importpath import import_path | |
|
14 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
15 | from rhodecode.lib.dbmigrate.migrate.versioning.util.keyedinstance import KeyedInstance | |
|
16 | from rhodecode.lib.dbmigrate.migrate.versioning.util.importpath import import_path | |
|
17 | 17 | |
|
18 | 18 | |
|
19 | 19 | log = logging.getLogger(__name__) |
|
20 | 20 | |
|
21 | 21 | def load_model(dotted_name): |
|
22 | 22 | """Import module and use module-level variable". |
|
23 | 23 | |
|
24 | 24 | :param dotted_name: path to model in form of string: ``some.python.module:Class`` |
|
25 | 25 | |
|
26 | 26 | .. versionchanged:: 0.5.4 |
|
27 | 27 | |
|
28 | 28 | """ |
|
29 | 29 | if isinstance(dotted_name, basestring): |
|
30 | 30 | if ':' not in dotted_name: |
|
31 | 31 | # backwards compatibility |
|
32 | 32 | warnings.warn('model should be in form of module.model:User ' |
|
33 | 33 | 'and not module.model.User', exceptions.MigrateDeprecationWarning) |
|
34 | 34 | dotted_name = ':'.join(dotted_name.rsplit('.', 1)) |
|
35 | 35 | return EntryPoint.parse('x=%s' % dotted_name).load(False) |
|
36 | 36 | else: |
|
37 | 37 | # Assume it's already loaded. |
|
38 | 38 | return dotted_name |
|
39 | 39 | |
|
40 | 40 | def asbool(obj): |
|
41 | 41 | """Do everything to use object as bool""" |
|
42 | 42 | if isinstance(obj, basestring): |
|
43 | 43 | obj = obj.strip().lower() |
|
44 | 44 | if obj in ['true', 'yes', 'on', 'y', 't', '1']: |
|
45 | 45 | return True |
|
46 | 46 | elif obj in ['false', 'no', 'off', 'n', 'f', '0']: |
|
47 | 47 | return False |
|
48 | 48 | else: |
|
49 | 49 | raise ValueError("String is not true/false: %r" % obj) |
|
50 | 50 | if obj in (True, False): |
|
51 | 51 | return bool(obj) |
|
52 | 52 | else: |
|
53 | 53 | raise ValueError("String is not true/false: %r" % obj) |
|
54 | 54 | |
|
55 | 55 | def guess_obj_type(obj): |
|
56 | 56 | """Do everything to guess object type from string |
|
57 | 57 | |
|
58 | 58 | Tries to convert to `int`, `bool` and finally returns if not succeded. |
|
59 | 59 | |
|
60 | 60 | .. versionadded: 0.5.4 |
|
61 | 61 | """ |
|
62 | 62 | |
|
63 | 63 | result = None |
|
64 | 64 | |
|
65 | 65 | try: |
|
66 | 66 | result = int(obj) |
|
67 | 67 | except: |
|
68 | 68 | pass |
|
69 | 69 | |
|
70 | 70 | if result is None: |
|
71 | 71 | try: |
|
72 | 72 | result = asbool(obj) |
|
73 | 73 | except: |
|
74 | 74 | pass |
|
75 | 75 | |
|
76 | 76 | if result is not None: |
|
77 | 77 | return result |
|
78 | 78 | else: |
|
79 | 79 | return obj |
|
80 | 80 | |
|
81 | 81 | @decorator |
|
82 | 82 | def catch_known_errors(f, *a, **kw): |
|
83 | 83 | """Decorator that catches known api errors |
|
84 | 84 | |
|
85 | 85 | .. versionadded: 0.5.4 |
|
86 | 86 | """ |
|
87 | 87 | |
|
88 | 88 | try: |
|
89 | 89 | return f(*a, **kw) |
|
90 | 90 | except exceptions.PathFoundError, e: |
|
91 | 91 | raise exceptions.KnownError("The path %s already exists" % e.args[0]) |
|
92 | 92 | |
|
93 | 93 | def construct_engine(engine, **opts): |
|
94 | 94 | """.. versionadded:: 0.5.4 |
|
95 | 95 | |
|
96 | 96 | Constructs and returns SQLAlchemy engine. |
|
97 | 97 | |
|
98 | 98 | Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions: |
|
99 | 99 | |
|
100 | 100 | :param engine: connection string or a existing engine |
|
101 | 101 | :param engine_dict: python dictionary of options to pass to `create_engine` |
|
102 | 102 | :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`) |
|
103 | 103 | :type engine_dict: dict |
|
104 | 104 | :type engine: string or Engine instance |
|
105 | 105 | :type engine_arg_*: string |
|
106 | 106 | :returns: SQLAlchemy Engine |
|
107 | 107 | |
|
108 | 108 | .. note:: |
|
109 | 109 | |
|
110 | 110 | keyword parameters override ``engine_dict`` values. |
|
111 | 111 | |
|
112 | 112 | """ |
|
113 | 113 | if isinstance(engine, Engine): |
|
114 | 114 | return engine |
|
115 | 115 | elif not isinstance(engine, basestring): |
|
116 | 116 | raise ValueError("you need to pass either an existing engine or a database uri") |
|
117 | 117 | |
|
118 | 118 | # get options for create_engine |
|
119 | 119 | if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): |
|
120 | 120 | kwargs = opts['engine_dict'] |
|
121 | 121 | else: |
|
122 | 122 | kwargs = dict() |
|
123 | 123 | |
|
124 | 124 | # DEPRECATED: handle echo the old way |
|
125 | 125 | echo = asbool(opts.get('echo', False)) |
|
126 | 126 | if echo: |
|
127 | 127 | warnings.warn('echo=True parameter is deprecated, pass ' |
|
128 | 128 | 'engine_arg_echo=True or engine_dict={"echo": True}', |
|
129 | 129 | exceptions.MigrateDeprecationWarning) |
|
130 | 130 | kwargs['echo'] = echo |
|
131 | 131 | |
|
132 | 132 | # parse keyword arguments |
|
133 | 133 | for key, value in opts.iteritems(): |
|
134 | 134 | if key.startswith('engine_arg_'): |
|
135 | 135 | kwargs[key[11:]] = guess_obj_type(value) |
|
136 | 136 | |
|
137 | 137 | log.debug('Constructing engine') |
|
138 | 138 | # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs) |
|
139 | 139 | # seems like 0.5.x branch does not work with engine.dispose and staticpool |
|
140 | 140 | return create_engine(engine, **kwargs) |
|
141 | 141 | |
|
142 | 142 | @decorator |
|
143 | 143 | def with_engine(f, *a, **kw): |
|
144 | 144 | """Decorator for :mod:`migrate.versioning.api` functions |
|
145 | 145 | to safely close resources after function usage. |
|
146 | 146 | |
|
147 | 147 | Passes engine parameters to :func:`construct_engine` and |
|
148 | 148 | resulting parameter is available as kw['engine']. |
|
149 | 149 | |
|
150 | 150 | Engine is disposed after wrapped function is executed. |
|
151 | 151 | |
|
152 | 152 | .. versionadded: 0.6.0 |
|
153 | 153 | """ |
|
154 | 154 | url = a[0] |
|
155 | 155 | engine = construct_engine(url, **kw) |
|
156 | 156 | |
|
157 | 157 | try: |
|
158 | 158 | kw['engine'] = engine |
|
159 | 159 | return f(*a, **kw) |
|
160 | 160 | finally: |
|
161 | 161 | if isinstance(engine, Engine): |
|
162 | 162 | log.debug('Disposing SQLAlchemy engine %s', engine) |
|
163 | 163 | engine.dispose() |
|
164 | 164 | |
|
165 | 165 | |
|
166 | 166 | class Memoize: |
|
167 | 167 | """Memoize(fn) - an instance which acts like fn but memoizes its arguments |
|
168 | 168 | Will only work on functions with non-mutable arguments |
|
169 | 169 | |
|
170 | 170 | ActiveState Code 52201 |
|
171 | 171 | """ |
|
172 | 172 | def __init__(self, fn): |
|
173 | 173 | self.fn = fn |
|
174 | 174 | self.memo = {} |
|
175 | 175 | |
|
176 | 176 | def __call__(self, *args): |
|
177 | 177 | if not self.memo.has_key(args): |
|
178 | 178 | self.memo[args] = self.fn(*args) |
|
179 | 179 | return self.memo[args] |
@@ -1,215 +1,215 b'' | |||
|
1 | 1 | #!/usr/bin/env python |
|
2 | 2 | # -*- coding: utf-8 -*- |
|
3 | 3 | |
|
4 | 4 | import os |
|
5 | 5 | import re |
|
6 | 6 | import shutil |
|
7 | 7 | import logging |
|
8 | 8 | |
|
9 | from migrate import exceptions | |
|
10 | from migrate.versioning import pathed, script | |
|
9 | from rhodecode.lib.dbmigrate.migrate import exceptions | |
|
10 | from rhodecode.lib.dbmigrate.migrate.versioning import pathed, script | |
|
11 | 11 | |
|
12 | 12 | |
|
13 | 13 | log = logging.getLogger(__name__) |
|
14 | 14 | |
|
15 | 15 | class VerNum(object): |
|
16 | 16 | """A version number that behaves like a string and int at the same time""" |
|
17 | 17 | |
|
18 | 18 | _instances = dict() |
|
19 | 19 | |
|
20 | 20 | def __new__(cls, value): |
|
21 | 21 | val = str(value) |
|
22 | 22 | if val not in cls._instances: |
|
23 | 23 | cls._instances[val] = super(VerNum, cls).__new__(cls) |
|
24 | 24 | ret = cls._instances[val] |
|
25 | 25 | return ret |
|
26 | 26 | |
|
27 | 27 | def __init__(self,value): |
|
28 | 28 | self.value = str(int(value)) |
|
29 | 29 | if self < 0: |
|
30 | 30 | raise ValueError("Version number cannot be negative") |
|
31 | 31 | |
|
32 | 32 | def __add__(self, value): |
|
33 | 33 | ret = int(self) + int(value) |
|
34 | 34 | return VerNum(ret) |
|
35 | 35 | |
|
36 | 36 | def __sub__(self, value): |
|
37 | 37 | return self + (int(value) * -1) |
|
38 | 38 | |
|
39 | 39 | def __cmp__(self, value): |
|
40 | 40 | return int(self) - int(value) |
|
41 | 41 | |
|
42 | 42 | def __repr__(self): |
|
43 | 43 | return "<VerNum(%s)>" % self.value |
|
44 | 44 | |
|
45 | 45 | def __str__(self): |
|
46 | 46 | return str(self.value) |
|
47 | 47 | |
|
48 | 48 | def __int__(self): |
|
49 | 49 | return int(self.value) |
|
50 | 50 | |
|
51 | 51 | |
|
52 | 52 | class Collection(pathed.Pathed): |
|
53 | 53 | """A collection of versioning scripts in a repository""" |
|
54 | 54 | |
|
55 | 55 | FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*') |
|
56 | 56 | |
|
57 | 57 | def __init__(self, path): |
|
58 | 58 | """Collect current version scripts in repository |
|
59 | 59 | and store them in self.versions |
|
60 | 60 | """ |
|
61 | 61 | super(Collection, self).__init__(path) |
|
62 | 62 | |
|
63 | 63 | # Create temporary list of files, allowing skipped version numbers. |
|
64 | 64 | files = os.listdir(path) |
|
65 | 65 | if '1' in files: |
|
66 | 66 | # deprecation |
|
67 | 67 | raise Exception('It looks like you have a repository in the old ' |
|
68 | 68 | 'format (with directories for each version). ' |
|
69 | 69 | 'Please convert repository before proceeding.') |
|
70 | 70 | |
|
71 | 71 | tempVersions = dict() |
|
72 | 72 | for filename in files: |
|
73 | 73 | match = self.FILENAME_WITH_VERSION.match(filename) |
|
74 | 74 | if match: |
|
75 | 75 | num = int(match.group(1)) |
|
76 | 76 | tempVersions.setdefault(num, []).append(filename) |
|
77 | 77 | else: |
|
78 | 78 | pass # Must be a helper file or something, let's ignore it. |
|
79 | 79 | |
|
80 | 80 | # Create the versions member where the keys |
|
81 | 81 | # are VerNum's and the values are Version's. |
|
82 | 82 | self.versions = dict() |
|
83 | 83 | for num, files in tempVersions.items(): |
|
84 | 84 | self.versions[VerNum(num)] = Version(num, path, files) |
|
85 | 85 | |
|
86 | 86 | @property |
|
87 | 87 | def latest(self): |
|
88 | 88 | """:returns: Latest version in Collection""" |
|
89 | 89 | return max([VerNum(0)] + self.versions.keys()) |
|
90 | 90 | |
|
91 | 91 | def create_new_python_version(self, description, **k): |
|
92 | 92 | """Create Python files for new version""" |
|
93 | 93 | ver = self.latest + 1 |
|
94 | 94 | extra = str_to_filename(description) |
|
95 | 95 | |
|
96 | 96 | if extra: |
|
97 | 97 | if extra == '_': |
|
98 | 98 | extra = '' |
|
99 | 99 | elif not extra.startswith('_'): |
|
100 | 100 | extra = '_%s' % extra |
|
101 | 101 | |
|
102 | 102 | filename = '%03d%s.py' % (ver, extra) |
|
103 | 103 | filepath = self._version_path(filename) |
|
104 | 104 | |
|
105 | 105 | script.PythonScript.create(filepath, **k) |
|
106 | 106 | self.versions[ver] = Version(ver, self.path, [filename]) |
|
107 | 107 | |
|
108 | 108 | def create_new_sql_version(self, database, **k): |
|
109 | 109 | """Create SQL files for new version""" |
|
110 | 110 | ver = self.latest + 1 |
|
111 | 111 | self.versions[ver] = Version(ver, self.path, []) |
|
112 | 112 | |
|
113 | 113 | # Create new files. |
|
114 | 114 | for op in ('upgrade', 'downgrade'): |
|
115 | 115 | filename = '%03d_%s_%s.sql' % (ver, database, op) |
|
116 | 116 | filepath = self._version_path(filename) |
|
117 | 117 | script.SqlScript.create(filepath, **k) |
|
118 | 118 | self.versions[ver].add_script(filepath) |
|
119 | 119 | |
|
120 | 120 | def version(self, vernum=None): |
|
121 | 121 | """Returns latest Version if vernum is not given. |
|
122 | 122 | Otherwise, returns wanted version""" |
|
123 | 123 | if vernum is None: |
|
124 | 124 | vernum = self.latest |
|
125 | 125 | return self.versions[VerNum(vernum)] |
|
126 | 126 | |
|
127 | 127 | @classmethod |
|
128 | 128 | def clear(cls): |
|
129 | 129 | super(Collection, cls).clear() |
|
130 | 130 | |
|
131 | 131 | def _version_path(self, ver): |
|
132 | 132 | """Returns path of file in versions repository""" |
|
133 | 133 | return os.path.join(self.path, str(ver)) |
|
134 | 134 | |
|
135 | 135 | |
|
136 | 136 | class Version(object): |
|
137 | 137 | """A single version in a collection |
|
138 | 138 | :param vernum: Version Number |
|
139 | 139 | :param path: Path to script files |
|
140 | 140 | :param filelist: List of scripts |
|
141 | 141 | :type vernum: int, VerNum |
|
142 | 142 | :type path: string |
|
143 | 143 | :type filelist: list |
|
144 | 144 | """ |
|
145 | 145 | |
|
146 | 146 | def __init__(self, vernum, path, filelist): |
|
147 | 147 | self.version = VerNum(vernum) |
|
148 | 148 | |
|
149 | 149 | # Collect scripts in this folder |
|
150 | 150 | self.sql = dict() |
|
151 | 151 | self.python = None |
|
152 | 152 | |
|
153 | 153 | for script in filelist: |
|
154 | 154 | self.add_script(os.path.join(path, script)) |
|
155 | 155 | |
|
156 | 156 | def script(self, database=None, operation=None): |
|
157 | 157 | """Returns SQL or Python Script""" |
|
158 | 158 | for db in (database, 'default'): |
|
159 | 159 | # Try to return a .sql script first |
|
160 | 160 | try: |
|
161 | 161 | return self.sql[db][operation] |
|
162 | 162 | except KeyError: |
|
163 | 163 | continue # No .sql script exists |
|
164 | 164 | |
|
165 | 165 | # TODO: maybe add force Python parameter? |
|
166 | 166 | ret = self.python |
|
167 | 167 | |
|
168 | 168 | assert ret is not None, \ |
|
169 | 169 | "There is no script for %d version" % self.version |
|
170 | 170 | return ret |
|
171 | 171 | |
|
172 | 172 | def add_script(self, path): |
|
173 | 173 | """Add script to Collection/Version""" |
|
174 | 174 | if path.endswith(Extensions.py): |
|
175 | 175 | self._add_script_py(path) |
|
176 | 176 | elif path.endswith(Extensions.sql): |
|
177 | 177 | self._add_script_sql(path) |
|
178 | 178 | |
|
179 | 179 | SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql') |
|
180 | 180 | |
|
181 | 181 | def _add_script_sql(self, path): |
|
182 | 182 | basename = os.path.basename(path) |
|
183 | 183 | match = self.SQL_FILENAME.match(basename) |
|
184 | 184 | |
|
185 | 185 | if match: |
|
186 | 186 | version, dbms, op = match.group(1), match.group(2), match.group(3) |
|
187 | 187 | else: |
|
188 | 188 | raise exceptions.ScriptError( |
|
189 | 189 | "Invalid SQL script name %s " % basename + \ |
|
190 | 190 | "(needs to be ###_database_operation.sql)") |
|
191 | 191 | |
|
192 | 192 | # File the script into a dictionary |
|
193 | 193 | self.sql.setdefault(dbms, {})[op] = script.SqlScript(path) |
|
194 | 194 | |
|
195 | 195 | def _add_script_py(self, path): |
|
196 | 196 | if self.python is not None: |
|
197 | 197 | raise exceptions.ScriptError('You can only have one Python script ' |
|
198 | 198 | 'per version, but you have: %s and %s' % (self.python, path)) |
|
199 | 199 | self.python = script.PythonScript(path) |
|
200 | 200 | |
|
201 | 201 | |
|
202 | 202 | class Extensions: |
|
203 | 203 | """A namespace for file extensions""" |
|
204 | 204 | py = 'py' |
|
205 | 205 | sql = 'sql' |
|
206 | 206 | |
|
207 | 207 | def str_to_filename(s): |
|
208 | 208 | """Replaces spaces, (double and single) quotes |
|
209 | 209 | and double underscores to underscores |
|
210 | 210 | """ |
|
211 | 211 | |
|
212 | 212 | s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_") |
|
213 | 213 | while '__' in s: |
|
214 | 214 | s = s.replace('__', '_') |
|
215 | 215 | return s |
General Comments 0
You need to be logged in to leave comments.
Login now