##// END OF EJS Templates
python3: 2to3 fixes
super-admin -
r4930:f88262ff default
parent child Browse files
Show More
@@ -1,1764 +1,1761 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 import collections
3 import collections
4 import copy
4 import copy
5 import datetime
5 import datetime
6 import hashlib
6 import hashlib
7 import hmac
7 import hmac
8 import json
8 import json
9 import logging
9 import logging
10 try:
10 import pickle
11 import cPickle as pickle
12 except ImportError:
13 import pickle
14 import sys
11 import sys
15 import threading
12 import threading
16 import time
13 import time
17 from xml.etree import ElementTree
14 from xml.etree import ElementTree
18
15
19 from authomatic.exceptions import (
16 from authomatic.exceptions import (
20 ConfigError,
17 ConfigError,
21 CredentialsError,
18 CredentialsError,
22 ImportStringError,
19 ImportStringError,
23 RequestElementsError,
20 RequestElementsError,
24 SessionError,
21 SessionError,
25 )
22 )
26 from authomatic import six
23 from authomatic import six
27 from authomatic.six.moves import urllib_parse as parse
24 from authomatic.six.moves import urllib_parse as parse
28
25
29
26
30 # =========================================================================
27 # =========================================================================
31 # Global variables !!!
28 # Global variables !!!
32 # =========================================================================
29 # =========================================================================
33
30
# Module-wide logger; a stream handler writing to standard output is attached
# as soon as this module is imported.
_logger = logging.getLogger(__name__)
_stdout_handler = logging.StreamHandler(sys.stdout)
_logger.addHandler(_stdout_handler)

# Placeholder for the shared provider-id counter.
_counter = None
38
35
39
36
def normalize_dict(dict_):
    """
    Replaces all values that are single-item sequences with the value of
    their index 0.

    Text (``str``) and binary (``bytes``) values are treated as scalars and
    are returned unchanged, even when their length is 1.

    :param dict dict_:
        Dictionary to normalize. Every value must support ``len()``.

    :returns:
        A new normalized dictionary; ``dict_`` itself is not modified.

    """

    def unwrap(value):
        # Indexing a one-byte ``bytes`` object yields an ``int``, so bytes
        # have to be excluded together with ``str``.
        if isinstance(value, (str, bytes)):
            return value
        return value[0] if len(value) == 1 else value

    return {key: unwrap(value) for key, value in dict_.items()}
55
52
56
53
def items_to_dict(items):
    """
    Builds a dictionary from ``(key, value)`` pairs, collecting the values of
    repeated keys into lists.

    The grouped result is passed through :func:`normalize_dict`, so a key
    that occurred only once maps directly to its value.

    :param list items:
        List of ``(key, value)`` tuples.

    :returns:
        :class:`dict`

    """

    grouped = {}

    for key, value in items:
        # Start a new bucket on the first occurrence of the key.
        grouped.setdefault(key, []).append(value)

    return normalize_dict(grouped)
76
73
77
74
class Counter(object):
    """
    Minimal incrementing counter used by the config helpers to hand out
    unique `id` values.
    """

    def __init__(self, start=0):
        # The first call of count() returns ``start + 1``.
        self._count = start

    def count(self):
        """
        Increments the internal value by one and returns the new value.
        """
        incremented = self._count + 1
        self._count = incremented
        return incremented


# Module-wide counter instance.
_counter = Counter()
92
89
93
90
def provider_id():
    """
    Returns the next value of the module-wide counter; meant for generating
    unique provider `IDs` in the config.

    :returns:
        :class:`int`.

    Use it in the :doc:`config` like this:
    ::

        import authomatic

        CONFIG = {
            'facebook': {
                'class_': authomatic.providers.oauth2.Facebook,
                'id': authomatic.provider_id(), # returns 1
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['user_about_me', 'email']
            },
            'google': {
                'class_': 'authomatic.providers.oauth2.Google',
                'id': authomatic.provider_id(), # returns 2
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['https://www.googleapis.com/auth/userinfo.profile',
                          'https://www.googleapis.com/auth/userinfo.email']
            },
            'windows_live': {
                'class_': 'oauth2.WindowsLive',
                'id': authomatic.provider_id(), # returns 3
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['wl.basic', 'wl.emails', 'wl.photos']
            },
        }

    """

    next_id = _counter.count()
    return next_id
134
131
135
132
def escape(s):
    """
    Percent-encodes the UTF-8 form of a string; only ``~`` is kept as an
    extra safe character, so ``/`` is escaped as well.
    """
    utf8_bytes = s.encode('utf-8')
    return parse.quote(utf8_bytes, safe='~')
141
138
142
139
def json_qs_parser(body):
    """
    Parses a response body, trying JSON, then XML and finally query string.

    :param body:
        string

    :returns:
        :class:`dict`, :class:`list` if input is JSON or query string,
        :class:`xml.etree.ElementTree.Element` if XML.

    """
    # Parsers in order of preference, each paired with the errors that mean
    # "the body is not in this format, try the next one".
    attempts = (
        (json.loads, (OverflowError, TypeError, ValueError)),
        (ElementTree.fromstring,
         (ElementTree.ParseError, TypeError, ValueError)),
    )

    for parser, not_this_format in attempts:
        try:
            return parser(body)
        except not_this_format:
            continue

    # Nothing else matched: treat the body as a query string.
    pairs = parse.parse_qsl(body)
    return dict(pairs)
169
166
170
167
def import_string(import_name, silent=False):
    """
    Imports an object by string in dotted notation.

    A name without a dot is imported as a module. Otherwise the part after
    the last dot is looked up as an attribute of the module named by the
    part before it.

    taken `from webapp2.import_string() <http://webapp-
    improved.appspot.com/api/webapp2.html#webapp2.import_string>`_

    :param str import_name:
        Dotted path of the object to import.
    :param bool silent:
        If ``True``, ``None`` is returned on failure instead of raising.

    :returns:
        The imported object, or ``None`` if the import failed and
        :data:`silent` is true.

    """

    try:
        if '.' not in import_name:
            return __import__(import_name)

        module_path, attr_name = import_name.rsplit('.', 1)
        # A non-empty fromlist makes __import__ return the leaf module
        # rather than the top-level package.
        module = __import__(module_path, None, None, [attr_name])
        return getattr(module, attr_name)
    except (ImportError, AttributeError) as e:
        if silent:
            return None
        raise ImportStringError(
            'Import from string failed for path {0}'.format(import_name),
            str(e))
190
187
191
188
def resolve_provider_class(class_):
    """
    Returns a provider class.

    :param class_:
        :class:`str` import path, or any other object, which is returned
        unchanged (expected to be an
        :class:`authomatic.providers.BaseProvider` subclass).

    """

    if not isinstance(class_, str):
        return class_

    # Path of the class relative to this package's ``providers`` package.
    package_path = '.'.join([__package__, 'providers', class_])

    # First try the string exactly as given, without raising on failure.
    resolved = import_string(class_, True)
    if resolved:
        return resolved

    # Fall back to the path inside the providers package; this one raises
    # if the import fails.
    return import_string(package_path)
210
207
211
208
def id_to_name(config, short_name):
    """
    Returns the provider :doc:`config` key whose ``id`` value equals
    :data:`short_name`.

    :param dict config:
        :doc:`config`.
    :param short_name:
        Value of the id parameter in the :ref:`config` to search for.

    :raises Exception:
        If no provider in the config has that ``id``.

    """

    # Unique sentinel so that any key, even ``None``, counts as a match.
    not_found = object()

    provider_name = next(
        (name for name, settings in config.items()
         if settings.get('id') == short_name),
        not_found)

    if provider_name is not_found:
        raise Exception(
            'No provider with id={0} found in the config!'.format(short_name))

    return provider_name
229
226
230
227
class ReprMixin(object):
    """
    Adds a __repr__() producing *ClassName(arg1=value, arg2=value)* from the
    instance attributes.

    Attributes are left out when

    * their value is considered false,
    * their name starts with an underscore,
    * they are listed in _repr_ignore.

    Attributes listed in _repr_sensitive are shown with *###* in place of
    their value. A value whose repr() is longer than _repr_length_limit is
    shown as *ClassName(...)*.

    """

    #: Iterable of attributes to be ignored.
    _repr_ignore = []
    #: Iterable of attributes which value should not be visible.
    _repr_sensitive = []
    #: `int` Values longer than this will be truncated to *ClassName(...)*.
    _repr_length_limit = 20

    def __repr__(self):

        class_name = self.__class__.__name__
        rendered = []

        # Iterate over a snapshot of the instance attributes.
        for attr, value in tuple(vars(self).items()):

            # Skip falsy values, private attributes and ignored attributes.
            if (not value or attr.startswith('_')
                    or attr in self._repr_ignore):
                continue

            # Mask sensitive values before they are rendered.
            shown = '###' if attr in self._repr_sensitive else value

            text = repr(shown)
            if len(text) > self._repr_length_limit:
                # Too long: show only the type as ClassName(...)
                text = '{0}(...)'.format(shown.__class__.__name__)

            rendered.append('{0}={1}'.format(attr, text))

        return '{0}({1})'.format(class_name, ', '.join(rendered))
282
279
283
280
class Future(threading.Thread):
    """
    An activity executed in its own thread. Extends the standard library
    :class:`threading.Thread` with the :attr:`.get_result` method.

    .. warning::

        |async|

    """

    def __init__(self, func, *args, **kwargs):
        """
        :param callable func:
            The function to be run in separate thread.

        The thread is started from within the constructor, so :data:`func`
        begins running in the new thread and the constructor returns
        immediately. All remaining positional and keyword arguments are
        passed on to :data:`func`.
        """

        super(Future, self).__init__()

        # Remember the call; run() performs it in the new thread.
        self._func = func
        self._args = args
        self._kwargs = kwargs
        # Stays None until func has returned.
        self._result = None

        self.start()

    def run(self):
        outcome = self._func(*self._args, **self._kwargs)
        self._result = outcome

    def get_result(self, timeout=None):
        """
        Joins the thread and returns the result of the wrapped :data:`func`.

        .. note::

            This will block the **calling thread** until the :data:`func`
            returns or the timeout elapses.

        :param timeout:
            :class:`float` or ``None`` A timeout for the :data:`func` to
            return in seconds.

        :returns:
            The result of the wrapped :data:`func`; ``None`` if it has not
            produced a result yet.

        """

        self.join(timeout)
        return self._result
336
333
337
334
class Session(object):
    """
    A dictionary-like secure cookie session implementation.

    The session data is pickled, percent-encoded, time-stamped and signed
    with an HMAC-SHA1 computed from the cookie name, the encoded payload and
    the timestamp.
    """

    def __init__(self, adapter, secret, name='authomatic', max_age=600,
                 secure=False):
        """
        :param adapter:
            Object providing ``url``, ``cookies`` and ``set_header()``.
        :param str secret:
            Session secret used to sign the session cookie.
        :param str name:
            Session cookie name.
        :param int max_age:
            Maximum allowed age of session cookie nonce in seconds.
        :param bool secure:
            If ``True`` the session cookie will be saved with ``Secure``
            attribute.
        """

        self.adapter = adapter
        self.name = name
        self.secret = secret
        self.max_age = max_age
        self.secure = secure
        self._data = {}

    def create_cookie(self, delete=None):
        """
        Creates the value for ``Set-Cookie`` HTTP header.

        :param bool delete:
            If ``True`` the cookie value will be ``deleted`` and the
            Expires value will be ``Thu, 01-Jan-1970 00:00:01 GMT``.

        :returns:
            :class:`str` value of the ``Set-Cookie`` header.

        """
        value = 'deleted' if delete else self._serialize(self.data)
        split_url = parse.urlsplit(self.adapter.url)
        # Host name without the port.
        domain = split_url.netloc.split(':')[0]

        # Work-around for issue #11, failure of WebKit-based browsers to accept
        # cookies set as part of a redirect response in some circumstances.
        if '.' not in domain:
            template = '{name}={value}; Path={path}; HttpOnly{secure}{expires}'
        else:
            template = ('{name}={value}; Domain={domain}; Path={path}; '
                        'HttpOnly{secure}{expires}')

        return template.format(
            name=self.name,
            value=value,
            domain=domain,
            path=split_url.path,
            secure='; Secure' if self.secure else '',
            expires='; Expires=Thu, 01-Jan-1970 00:00:01 GMT' if delete else ''
        )

    def save(self):
        """
        Adds the session cookie to headers.

        Nothing is sent when the session holds no data. The in-memory data
        is reset afterwards in either case.

        :raises SessionError:
            If the resulting cookie is longer than 4093 characters.
        """
        if self.data:
            cookie = self.create_cookie()
            cookie_len = len(cookie)

            if cookie_len > 4093:
                raise SessionError('Cookie too long! The cookie size {0} '
                                   'is more than 4093 bytes.'
                                   .format(cookie_len))

            self.adapter.set_header('Set-Cookie', cookie)

        # Reset data
        self._data = {}

    def delete(self):
        """
        Adds a header that overwrites the session cookie with an expired one.
        """
        self.adapter.set_header('Set-Cookie', self.create_cookie(delete=True))

    def _get_data(self):
        """
        Extracts the session data from cookie.
        """
        cookie = self.adapter.cookies.get(self.name)
        return self._deserialize(cookie) if cookie else {}

    @property
    def data(self):
        """
        Gets session data lazily.
        """
        if not self._data:
            self._data = self._get_data()
        # Always return a dict, even if deserialization returned nothing
        if self._data is None:
            self._data = {}
        return self._data

    def _signature(self, *parts):
        """
        Creates signature for the session.

        :param parts:
            Strings which are joined with ``|`` and signed.

        :returns:
            :class:`str` hexadecimal HMAC-SHA1 digest.
        """
        signature = hmac.new(six.b(self.secret), digestmod=hashlib.sha1)
        signature.update(six.b('|'.join(parts)))
        return signature.hexdigest()

    def _serialize(self, value):
        """
        Converts the value to a signed string with timestamp.

        :param value:
            Object to be serialized.

        :returns:
            Serialized value.

        """

        # 1. Serialize
        # latin-1 maps every byte to exactly one character, so the pickle
        # survives the round trip through str.
        serialized = pickle.dumps(value).decode('latin-1')

        # 2. Encode
        # Percent encoding produces smaller result then urlsafe base64.
        encoded = parse.quote(serialized, '')

        # 3. Concatenate
        timestamp = str(int(time.time()))
        signature = self._signature(self.name, encoded, timestamp)
        concatenated = '|'.join([encoded, timestamp, signature])

        return concatenated

    def _deserialize(self, value):
        """
        Deserializes and verifies the value created by :meth:`._serialize`.

        :param str value:
            The serialized value.

        :returns:
            Deserialized object, or ``None`` if the value is older than
            ``max_age`` seconds.

        :raises SessionError:
            If the signature does not match.

        """

        # 3. Split
        encoded, timestamp, signature = value.split('|')

        # Verify signature
        # Compare in constant time so that the response time does not reveal
        # how much of a forged signature was correct. Both sides are encoded
        # to bytes because compare_digest() rejects non-ASCII str arguments.
        expected = self._signature(self.name, encoded, timestamp)
        if not hmac.compare_digest(signature.encode('utf-8', 'replace'),
                                   expected.encode('utf-8')):
            raise SessionError('Invalid signature "{0}"!'.format(signature))

        # Verify timestamp
        if int(timestamp) < int(time.time()) - self.max_age:
            return None

        # 2. Decode
        decoded = parse.unquote(encoded)

        # 1. Deserialize
        # NOTE(review): unpickling can execute arbitrary code. It is only
        # reached after the signature check above, so its safety depends
        # entirely on the secret staying private.
        deserialized = pickle.loads(decoded.encode('latin-1'))

        return deserialized

    def __setitem__(self, key, value):
        self._data[key] = value

    def __getitem__(self, key):
        return self.data.__getitem__(key)

    def __delitem__(self, key):
        return self._data.__delitem__(key)

    def get(self, key, default=None):
        return self.data.get(key, default)
513
510
514
511
class User(ReprMixin):
    """
    Unified view of selected **user** information returned by the different
    **providers**.

    .. note:: The value format may vary across providers.

    """

    def __init__(self, provider, **kwargs):
        # Every optional attribute falls back to ``None`` when not supplied.
        get = kwargs.get

        #: A :doc:`provider <providers>` instance.
        self.provider = provider

        #: An :class:`.Credentials` instance.
        self.credentials = get('credentials')

        #: A :class:`dict` containing all the **user** information returned
        #: by the **provider**.
        #: The structure differs across **providers**.
        self.data = get('data')

        #: The :attr:`.Response.content` of the request made to update
        #: the user.
        self.content = get('content')

        #: :class:`str` ID assigned to the **user** by the **provider**.
        self.id = get('id')
        #: :class:`str` User name e.g. *andrewpipkin*.
        self.username = get('username')
        #: :class:`str` Name e.g. *Andrew Pipkin*.
        self.name = get('name')
        #: :class:`str` First name e.g. *Andrew*.
        self.first_name = get('first_name')
        #: :class:`str` Last name e.g. *Pipkin*.
        self.last_name = get('last_name')
        #: :class:`str` Nickname e.g. *Andy*.
        self.nickname = get('nickname')
        #: :class:`str` Link URL.
        self.link = get('link')
        #: :class:`str` Gender.
        self.gender = get('gender')
        #: :class:`str` Timezone.
        self.timezone = get('timezone')
        #: :class:`str` Locale.
        self.locale = get('locale')
        #: :class:`str` E-mail.
        self.email = get('email')
        #: :class:`str` phone.
        self.phone = get('phone')
        #: :class:`str` Picture URL.
        self.picture = get('picture')
        #: Birth date as :class:`datetime.datetime()` or :class:`str`
        #: if parsing failed or ``None``.
        self.birth_date = get('birth_date')
        #: :class:`str` Country.
        self.country = get('country')
        #: :class:`str` City.
        self.city = get('city')
        #: :class:`str` Geographical location.
        self.location = get('location')
        #: :class:`str` Postal code.
        self.postal_code = get('postal_code')
        #: Instance of the Google App Engine Users API
        #: `User <https://developers.google.com/appengine/docs/python/users/userclass>`_ class.
        #: Only present when using the :class:`authomatic.providers.gaeopenid.GAEOpenID` provider.
        self.gae_user = get('gae_user')

    def update(self):
        """
        Refreshes the user info by delegating to the **provider's**
        ``update_user()`` method.

        :returns:
            Whatever ``provider.update_user()`` returns (the updated
            instance of this class).

        """

        return self.provider.update_user()

    def async_update(self):
        """
        Same as :meth:`.update` but runs asynchronously in a separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.update)

    def to_dict(self):
        """
        Converts the :class:`.User` instance to a :class:`dict`.

        The ``provider`` entry holds only the provider name, ``credentials``
        is serialized (or ``None``), ``birth_date`` is converted with
        ``str()``, ``content`` is left out and ``data`` is ``None`` when it
        is an XML element.

        :returns:
            :class:`dict`

        """

        # Shallow copy of the instance attributes.
        result = dict(self.__dict__)

        # Keep only the provider name to avoid circular reference
        result['provider'] = self.provider.name

        if self.credentials:
            result['credentials'] = self.credentials.serialize()
        else:
            result['credentials'] = None

        result['birth_date'] = str(result['birth_date'])

        # Remove content
        del result['content']

        if isinstance(self.data, ElementTree.Element):
            result['data'] = None

        return result
633
630
634
631
# Named tuple with one field per user attribute a provider may support.
SupportedUserAttributesNT = collections.namedtuple(
    'SupportedUserAttributesNT',
    (
        'birth_date',
        'city',
        'country',
        'email',
        'first_name',
        'gender',
        'id',
        'last_name',
        'link',
        'locale',
        'location',
        'name',
        'nickname',
        'phone',
        'picture',
        'postal_code',
        'timezone',
        'username',
    ),
)
642
639
643
640
class SupportedUserAttributes(SupportedUserAttributesNT):
    """
    :class:`SupportedUserAttributesNT` whose fields all default to ``False``.

    Only the attributes passed as keyword arguments are overridden.
    """

    def __new__(cls, **kwargs):
        # Start with every field switched off, then apply the overrides.
        flags = {
            field: False
            for field in SupportedUserAttributes._fields  # pylint:disable=no-member
        }
        flags.update(kwargs)
        return super(SupportedUserAttributes, cls).__new__(cls, **flags)
649
646
650
647
class Credentials(ReprMixin):
    """
    Contains all necessary information to fetch **user's protected resources**.
    """

    # Attributes listed here are treated as sensitive by the repr mixin.
    _repr_sensitive = ('token', 'refresh_token', 'token_secret',
                       'consumer_key', 'consumer_secret')

    def __init__(self, config, **kwargs):
        """
        :param dict config:
            :doc:`config`.

        Remaining values are read from ``kwargs``: ``token``, ``token_type``,
        ``refresh_token``, ``token_secret``, ``expiration_time``,
        ``expire_in`` and either ``provider`` (a provider instance) or the
        individual ``provider_*`` / ``consumer_*`` values.
        """

        #: :class:`dict` :doc:`config`.
        self.config = config

        #: :class:`str` User **access token**.
        self.token = kwargs.get('token', '')

        #: :class:`str` Access token type.
        self.token_type = kwargs.get('token_type', '')

        #: :class:`str` Refresh token.
        self.refresh_token = kwargs.get('refresh_token', '')

        #: :class:`str` Access token secret.
        self.token_secret = kwargs.get('token_secret', '')

        #: :class:`int` Expiration date as UNIX timestamp.
        self.expiration_time = int(kwargs.get('expiration_time', 0))

        # A :doc:`Provider <providers>` instance (optional).
        provider = kwargs.get('provider')

        # NOTE: assigned after ``expiration_time`` on purpose to keep the
        # original initialisation order of the two coupled properties.
        self.expire_in = int(kwargs.get('expire_in', 0))

        if provider:
            #: :class:`str` Provider name specified in the :doc:`config`.
            self.provider_name = provider.name

            #: :class:`str` Provider type e.g.
            # ``"authomatic.providers.oauth2.OAuth2"``.
            self.provider_type = provider.get_type()

            #: Provider type ID (``provider.type_id``).
            self.provider_type_id = provider.type_id

            #: :class:`int` Provider ID (``provider.id``) or ``None``
            #: when the provider has no ID.
            self.provider_id = int(provider.id) if provider.id else None

            #: :class:`class` Provider class.
            self.provider_class = provider.__class__

            #: :class:`str` Consumer key specified in the :doc:`config`.
            self.consumer_key = provider.consumer_key

            #: :class:`str` Consumer secret specified in the :doc:`config`.
            self.consumer_secret = provider.consumer_secret

        else:
            self.provider_name = kwargs.get('provider_name', '')
            self.provider_type = kwargs.get('provider_type', '')
            self.provider_type_id = kwargs.get('provider_type_id')
            self.provider_id = kwargs.get('provider_id')
            self.provider_class = kwargs.get('provider_class')

            self.consumer_key = kwargs.get('consumer_key', '')
            self.consumer_secret = kwargs.get('consumer_secret', '')

    @property
    def expire_in(self):
        """
        Number of seconds until expiration, as last assigned directly or
        derived when :attr:`.expiration_time` was set.
        """

        return self._expire_in

    @expire_in.setter
    def expire_in(self, value):
        """
        Computes :attr:`.expiration_time` when the value is set.

        A falsy value leaves the stored expiration time untouched.
        """

        # pylint:disable=attribute-defined-outside-init
        if value:
            self._expiration_time = int(time.time()) + int(value)
        self._expire_in = value

    @property
    def expiration_time(self):
        """
        Expiration date as UNIX timestamp.
        """

        return self._expiration_time

    @expiration_time.setter
    def expiration_time(self, value):
        """
        Stores the timestamp and recomputes :attr:`.expire_in` relative
        to the current time.
        """

        # pylint:disable=attribute-defined-outside-init
        self._expiration_time = int(value)
        self._expire_in = self._expiration_time - int(time.time())

    @property
    def expiration_date(self):
        """
        Expiration date as :class:`datetime.datetime` or ``None`` if
        credentials never expire.

        ``None`` is returned whenever :attr:`.expire_in` is negative.
        """

        if self.expire_in < 0:
            return None
        else:
            return datetime.datetime.fromtimestamp(self.expiration_time)

    @property
    def valid(self):
        """
        ``True`` if credentials are valid, ``False`` if expired.

        Credentials without an expiration time are always valid.
        """

        if self.expiration_time:
            return self.expiration_time > int(time.time())
        else:
            return True

    def expire_soon(self, seconds):
        """
        Returns ``True`` if credentials expire sooner than specified.

        :param int seconds:
            Number of seconds.

        :returns:
            ``True`` if credentials expire sooner than specified,
            else ``False``.

        """

        if self.expiration_time:
            return self.expiration_time < int(time.time()) + int(seconds)
        else:
            return False

    def refresh(self, force=False, soon=86400):
        """
        Refreshes the credentials only if the **provider** supports it and if
        it will expire in less than one day. It does nothing in other cases.

        .. note::

            The credentials will be refreshed only if it gives sense
            i.e. only |oauth2|_ has the notion of credentials
            *refreshment/extension*.
            And there are also differences across providers e.g. Google
            supports refreshment only if there is a ``refresh_token`` in
            the credentials and that in turn is present only if the
            ``access_type`` parameter was set to ``offline`` in the
            **user authorization request**.

        :param bool force:
            If ``True`` the credentials will be refreshed even if they
            won't expire soon.

        :param int soon:
            Number of seconds specifying what means *soon*.

        :returns:
            The result of the provider's ``refresh_credentials()`` call,
            or ``None`` when no refresh was attempted.

        """

        if hasattr(self.provider_class, 'refresh_credentials'):
            if force or self.expire_soon(soon):
                logging.info('PROVIDER NAME: %s', self.provider_name)
                return self.provider_class(
                    self, None, self.provider_name).refresh_credentials(self)

    def async_refresh(self, *args, **kwargs):
        """
        Same as :meth:`.refresh` but runs asynchronously in a separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.refresh, *args, **kwargs)

    def provider_type_class(self):
        """
        Returns the :doc:`provider <providers>` class specified in the
        :doc:`config`.

        :returns:
            :class:`authomatic.providers.BaseProvider` subclass.

        """

        return resolve_provider_class(self.provider_type)

    def serialize(self):
        """
        Converts the credentials to a percent encoded string to be stored for
        later use.

        :returns:
            :class:`string`

        :raises ConfigError:
            If :attr:`provider_id` is ``None``.

        """

        if self.provider_id is None:
            raise ConfigError(
                'To serialize credentials you need to specify a '
                'unique integer under the "id" key in the config '
                'for each provider!')

        # Get the provider type specific items.
        rest = self.provider_type_class().to_tuple(self)

        # Provider ID and provider type ID are always the first two items.
        result = (self.provider_id, self.provider_type_id) + rest

        # Make sure that all items are strings.
        stringified = [str(i) for i in result]

        # Concatenate by newline.
        concatenated = '\n'.join(stringified)

        # Percent encode.
        return parse.quote(concatenated, '')

    @classmethod
    def deserialize(cls, config, credentials):
        """
        A *class method* which reconstructs credentials created by
        :meth:`serialize`. You can also pass it a :class:`.Credentials`
        instance.

        :param dict config:
            The same :doc:`config` used in the :func:`.login` to get the
            credentials.
        :param str credentials:
            :class:`string` The serialized credentials or
            :class:`.Credentials` instance.

        :returns:
            :class:`.Credentials`

        :raises CredentialsError:
            If the first serialized item is not an integer provider ID.

        """

        # Accept both serialized and normal.
        if isinstance(credentials, Credentials):
            return credentials

        decoded = parse.unquote(credentials)

        split = decoded.split('\n')

        # We need the provider ID to move forward. ``str.split`` never yields
        # ``None``, so the ID is validated by converting it to an integer.
        try:
            provider_id = int(split[0])
        except ValueError:
            raise CredentialsError(
                'To deserialize credentials you need to specify a unique '
                'integer under the "id" key in the config for each provider!')

        # Get provider config by short name.
        provider_name = id_to_name(config, provider_id)
        cfg = config.get(provider_name)

        # Get the provider class.
        ProviderClass = resolve_provider_class(cfg.get('class_'))

        deserialized = Credentials(config)

        deserialized.provider_id = provider_id
        deserialized.provider_type = ProviderClass.get_type()
        deserialized.provider_type_id = split[1]
        deserialized.provider_class = ProviderClass
        deserialized.provider_name = provider_name

        # Add provider type specific properties.
        return ProviderClass.reconstruct(split[2:], deserialized, cfg)
929
926
930
927
class LoginResult(ReprMixin):
    """
    Result of the :func:`authomatic.login` function.
    """

    def __init__(self, provider):
        #: A :doc:`provider <providers>` instance.
        self.provider = provider

        #: An instance of the :exc:`authomatic.exceptions.BaseError` subclass.
        self.error = None

    def popup_js(self, callback_name=None, indent=None,
                 custom=None, stay_open=False):
        """
        Returns JavaScript that:

        #.  Triggers the ``options.onLoginComplete(result, closer)``
            handler set with the :ref:`authomatic.setup() <js_setup>`
            function of :ref:`javascript.js <js>`.
        #.  Calls the JavaScript callback specified by :data:`callback_name`
            on the opener of the *login handler popup* and passes it the
            *login result* JSON object as first argument and the `closer`
            function which you should call in your callback to close the popup.

        :param str callback_name:
            The name of the javascript callback e.g ``foo.bar.loginCallback``
            will result in
            ``window.opener.foo.bar.loginCallback(result, closer);``
            in the output.

        :param int indent:
            The number of spaces to indent the JSON result object.
            If ``0`` or negative, only newlines are added.
            If ``None``, no newlines are added.

        :param custom:
            Any JSON serializable object that will be passed to the
            ``result.custom`` attribute.

        :param bool stay_open:
            If ``True``, the popup will stay open: the ``window.close()``
            call inside ``closer`` is commented out.

        :returns:
            :class:`str` with JavaScript.

        """

        custom_callback = """
        try {{ window.opener.{cb}(result, closer); }} catch(e) {{}}
        """.format(cb=callback_name) if callback_name else ''

        # TODO: Move the window.close() to the opener
        # ``{stay_open}`` expands to ``// `` when the popup should stay open,
        # which turns the ``window.close();`` line into a JavaScript comment.
        return """
        (function(){{

            closer = function(){{
                {stay_open}window.close();
            }};

            var result = {result};
            result.custom = {custom};

            {custom_callback}

            try {{
                window.opener.authomatic.loginComplete(result, closer);
            }} catch(e) {{}}

        }})();

        """.format(result=self.to_json(indent),
                   custom=json.dumps(custom),
                   custom_callback=custom_callback,
                   stay_open='// ' if stay_open else '')

    def popup_html(self, callback_name=None, indent=None,
                   title='Login | {0}', custom=None, stay_open=False):
        """
        Returns a HTML with JavaScript that:

        #.  Triggers the ``options.onLoginComplete(result, closer)`` handler
            set with the :ref:`authomatic.setup() <js_setup>` function of
            :ref:`javascript.js <js>`.
        #.  Calls the JavaScript callback specified by :data:`callback_name`
            on the opener of the *login handler popup* and passes it the
            *login result* JSON object as first argument and the `closer`
            function which you should call in your callback to close the popup.

        :param str callback_name:
            The name of the javascript callback e.g ``foo.bar.loginCallback``
            will result in
            ``window.opener.foo.bar.loginCallback(result, closer);``
            in the HTML.

        :param int indent:
            The number of spaces to indent the JSON result object.
            If ``0`` or negative, only newlines are added.
            If ``None``, no newlines are added.

        :param str title:
            The text of the HTML title. You can use ``{0}`` tag inside,
            which will be replaced by the provider name.

        :param custom:
            Any JSON serializable object that will be passed to the
            ``result.custom`` attribute.

        :param bool stay_open:
            If ``True``, the popup will stay open.

        :returns:
            :class:`str` with HTML.

        """

        return """
        <!DOCTYPE html>
        <html>
            <head><title>{title}</title></head>
            <body>
            <script type="text/javascript">
                {js}
            </script>
            </body>
        </html>
        """.format(
            title=title.format(self.provider.name if self.provider else ''),
            js=self.popup_js(callback_name, indent, custom, stay_open)
        )

    @property
    def user(self):
        """
        A :class:`.User` instance.

        ``None`` when there is no provider.
        """

        return self.provider.user if self.provider else None

    def to_dict(self):
        """
        Returns a :class:`dict` with the ``provider``, ``user`` and
        ``error`` of this result.
        """

        return dict(provider=self.provider, user=self.user, error=self.error)

    def to_json(self, indent=4):
        """
        Serializes the result to a JSON string.

        Objects that are not JSON serializable are converted with their
        ``to_dict()`` method when they have one, otherwise they are
        replaced by an empty string.

        :param int indent:
            Passed to :func:`json.dumps` as ``indent``.

        :returns:
            :class:`str` with JSON.
        """

        return json.dumps(self, default=lambda obj: obj.to_dict(
        ) if hasattr(obj, 'to_dict') else '', indent=indent)
1074
1071
1075
1072
class Response(ReprMixin):
    """
    Thin wrapper around :class:`httplib.HTTPResponse` which adds the
    :attr:`.content` and :attr:`.data` attributes.
    """

    def __init__(self, httplib_response, content_parser=None):
        """
        :param httplib_response:
            The wrapped :class:`httplib.HTTPResponse` instance.

        :param function content_parser:
            Callable which accepts :attr:`.content` as argument,
            parses it and returns the parsed data as :class:`dict`.
            Falls back to ``json_qs_parser`` when not supplied.
        """

        self.httplib_response = httplib_response
        self.content_parser = content_parser or json_qs_parser

        # Lazily filled caches for the ``data`` and ``content`` properties.
        self._data = None
        self._content = None

        #: Same as :attr:`httplib.HTTPResponse.msg`.
        self.msg = httplib_response.msg
        #: Same as :attr:`httplib.HTTPResponse.version`.
        self.version = httplib_response.version
        #: Same as :attr:`httplib.HTTPResponse.status`.
        self.status = httplib_response.status
        #: Same as :attr:`httplib.HTTPResponse.reason`.
        self.reason = httplib_response.reason

    def read(self, amt=None):
        """
        Forwards to :meth:`httplib.HTTPResponse.read`.

        :param amt:
            Passed through unchanged to the wrapped response.

        """

        wrapped = self.httplib_response
        return wrapped.read(amt)

    def getheader(self, name, default=None):
        """
        Forwards to :meth:`httplib.HTTPResponse.getheader`.

        :param name:
            Passed through unchanged to the wrapped response.
        :param default:
            Passed through unchanged to the wrapped response.

        """

        wrapped = self.httplib_response
        return wrapped.getheader(name, default)

    def fileno(self):
        """
        Forwards to :meth:`httplib.HTTPResponse.fileno`.
        """
        wrapped = self.httplib_response
        return wrapped.fileno()

    def getheaders(self):
        """
        Forwards to :meth:`httplib.HTTPResponse.getheaders`.
        """
        wrapped = self.httplib_response
        return wrapped.getheaders()

    @staticmethod
    def is_binary_string(content):
        """
        Tells whether ``content`` holds binary data, i.e. whether any byte
        remains after all text bytes have been deleted from it.
        """

        control_bytes = bytearray([7, 8, 9, 10, 12, 13, 27])
        printable_bytes = bytearray(range(0x20, 0x100))
        leftover = content.translate(None, control_bytes + printable_bytes)
        return bool(leftover)

    @property
    def content(self):
        """
        The whole response content.

        Binary content is kept as read; anything else is decoded as UTF-8.
        """

        if self._content:
            return self._content

        raw = self.httplib_response.read()
        self._content = (
            raw if self.is_binary_string(raw) else raw.decode('utf-8'))
        return self._content

    @property
    def data(self):
        """
        The result of running the content parser on :attr:`.content`.
        """

        if self._data:
            return self._data

        self._data = self.content_parser(self.content)
        return self._data
1174
1171
1175
1172
class UserInfoResponse(Response):
    """
    A :class:`.Response` which additionally carries the
    :attr:`~UserInfoResponse.user` attribute.
    """

    def __init__(self, user, *args, **kwargs):
        # Everything except ``user`` is handed over to :class:`.Response`.
        parent = super(UserInfoResponse, self)
        parent.__init__(*args, **kwargs)

        #: :class:`.User` instance.
        self.user = user
1187
1184
1188
1185
class RequestElements(tuple):
    """
    A tuple of ``(url, method, params, headers, body)`` request elements.

    With some additional properties.

    """

    def __new__(cls, url, method, params, headers, body):
        return tuple.__new__(cls, (url, method, params, headers, body))

    @property
    def url(self):
        """
        Request URL.
        """

        return self[0]

    @property
    def method(self):
        """
        HTTP method of the request.
        """

        return self[1]

    @property
    def params(self):
        """
        Dictionary of request parameters.
        """

        return self[2]

    @property
    def headers(self):
        """
        Dictionary of request headers.
        """

        return self[3]

    @property
    def body(self):
        """
        :class:`str` Body of ``POST``, ``PUT`` and ``PATCH`` requests.
        """

        return self[4]

    @property
    def query_string(self):
        """
        Query string of the request.

        Missing params (``None``) are treated like an empty mapping and
        yield an empty string.
        """

        return parse.urlencode(self.params or {})

    @property
    def full_url(self):
        """
        URL with query string.

        The query string is joined with ``?``, or with ``&`` when the URL
        already carries a query. Nothing is appended when there are no
        parameters.
        """

        query = self.query_string
        if not query:
            # No parameters: avoid leaving a dangling "?" on the URL.
            return self.url

        if self.url.endswith(('?', '&')):
            separator = ''
        elif '?' in self.url:
            # The URL already has a query part, so extend it.
            separator = '&'
        else:
            separator = '?'
        return self.url + separator + query

    def to_json(self):
        """
        Serialize the five request elements to a JSON object string.
        """

        return json.dumps(dict(url=self.url,
                               method=self.method,
                               params=self.params,
                               headers=self.headers,
                               body=self.body))
1262
1259
1263
1260
class Authomatic(object):
    def __init__(
        self, config, secret, session_max_age=600, secure_cookie=False,
        session=None, session_save_method=None, report_errors=True,
        debug=False, logging_level=logging.INFO, prefix='authomatic',
        logger=None
    ):
        """
        Encapsulates all the functionality of this package.

        :param dict config:
            :doc:`config`

        :param str secret:
            A secret string that will be used as the key for signing
            :class:`.Session` cookie and as a salt by *CSRF* token generation.

        :param session_max_age:
            Maximum allowed age of :class:`.Session` cookie nonce in seconds.

        :param bool secure_cookie:
            If ``True`` the :class:`.Session` cookie will be saved with
            ``Secure`` attribute.

        :param session:
            Custom dictionary-like session implementation.

        :param callable session_save_method:
            A method of the supplied session or any mechanism that saves the
            session data and cookie.

        :param bool report_errors:
            If ``True`` exceptions encountered during the **login procedure**
            will be caught and reported in the :attr:`.LoginResult.error`
            attribute.
            Default is ``True``.

        :param bool debug:
            If ``True`` traceback of exceptions will be written to response.
            Default is ``False``.

        :param int logging_level:
            The logging level threshold for the default logger as specified in
            the standard Python
            `logging library <https://docs.python.org/3/library/logging.html>`_.
            This setting is ignored when :data:`logger` is set.
            Default is ``logging.INFO``.

        :param str prefix:
            Prefix used as the :class:`.Session` cookie name.

        :param logger:
            A :class:`logging.Logger` instance.

        """

        self.config = config
        self.secret = secret
        self.session_max_age = session_max_age
        self.secure_cookie = secure_cookie
        self.session = session
        self.session_save_method = session_save_method
        self.report_errors = report_errors
        self.debug = debug
        self.logging_level = logging_level
        self.prefix = prefix
        # Without a caller supplied logger, use one private to this instance.
        self._logger = logger or logging.getLogger(str(id(self)))

        # Set logging level, but never override the level of a logger the
        # caller configured.
        if logger is None:
            self._logger.setLevel(logging_level)

    def login(self, adapter, provider_name, callback=None,
              session=None, session_saver=None, **kwargs):
        """
        If :data:`provider_name` specified, launches the login procedure for
        corresponding :doc:`provider </reference/providers>` and returns
        :class:`.LoginResult`.

        If :data:`provider_name` is empty, acts like
        :meth:`.Authomatic.backend`.

        .. warning::

            The method redirects the **user** to the **provider** which in
            turn redirects **him/her** back to the *request handler* where
            it has been called.

        :param adapter:
            The :doc:`adapter <adapters>` passed on to the session and to
            the provider.

        :param str provider_name:
            Name of the provider as specified in the keys of the :doc:`config`.

        :param callable callback:
            If specified the method will call the callback with
            :class:`.LoginResult` passed as argument and will return nothing.

        :param session:
            Custom session. It is used only when :data:`session_saver` is
            supplied as well; otherwise a new :class:`.Session` is created.

        :param callable session_saver:
            Callable that saves the custom :data:`session`. It is used only
            together with :data:`session`.

        .. note::

            Accepts additional keyword arguments that will be passed to
            :doc:`provider <providers>` constructor.

        :raises ConfigError:
            If the provider is missing from the config or its settings
            have no ``class_`` key.

        :returns:
            :class:`.LoginResult`, or ``None`` when acting as backend.

        """

        if not provider_name:
            # Act like backend.
            self.backend(adapter)
            return None

        # retrieve required settings for current provider and raise
        # exceptions if missing
        provider_settings = self.config.get(provider_name)
        if not provider_settings:
            raise ConfigError('Provider name "{0}" not specified!'
                              .format(provider_name))

        # A custom session is honoured only together with its saver; in
        # every other case fall back to the built-in cookie session.
        if session is None or session_saver is None:
            session = Session(adapter=adapter,
                              secret=self.secret,
                              max_age=self.session_max_age,
                              name=self.prefix,
                              secure=self.secure_cookie)

            session_saver = session.save

        # Resolve provider class.
        class_ = provider_settings.get('class_')
        if not class_:
            raise ConfigError(
                'The "class_" key not specified in the config'
                ' for provider {0}!'.format(provider_name))
        ProviderClass = resolve_provider_class(class_)

        # FIXME: Find a nicer solution
        ProviderClass._logger = self._logger

        # instantiate provider class
        provider = ProviderClass(self,
                                 adapter=adapter,
                                 provider_name=provider_name,
                                 callback=callback,
                                 session=session,
                                 session_saver=session_saver,
                                 **kwargs)

        # return login result
        return provider.login()

    def credentials(self, credentials):
        """
        Deserializes credentials.

        :param credentials:
            Credentials serialized with :meth:`.Credentials.serialize` or
            :class:`.Credentials` instance.

        :returns:
            :class:`.Credentials`

        """

        return Credentials.deserialize(self.config, credentials)

    def access(self, credentials, url, params=None, method='GET',
               headers=None, body='', max_redirects=5, content_parser=None):
        """
        Accesses **protected resource** on behalf of the **user**.

        :param credentials:
            The **user's** :class:`.Credentials` (serialized or normal).

        :param str url:
            The **protected resource** URL.

        :param dict params:
            Request parameters forwarded to the provider.

        :param str method:
            HTTP method of the request.

        :param dict headers:
            HTTP headers of the request.

        :param str body:
            Body of ``POST``, ``PUT`` and ``PATCH`` requests.

        :param int max_redirects:
            Maximum number of HTTP redirects to follow.

        :param function content_parser:
            A function to be used to parse the :attr:`.Response.data`
            from :attr:`.Response.content`.

        :returns:
            :class:`.Response`

        """

        # Deserialize credentials.
        credentials = Credentials.deserialize(self.config, credentials)

        # Resolve provider class.
        ProviderClass = credentials.provider_class
        # Log through the instance logger so a custom ``logger`` and
        # ``logging_level`` are respected.
        # NOTE(review): headers may carry an Authorization value; consider
        # lowering this to DEBUG or redacting it.
        self._logger.info('ACCESS HEADERS: %s', headers)

        # Access resource and return response.
        provider = ProviderClass(
            self, adapter=None, provider_name=credentials.provider_name)
        provider.credentials = credentials

        return provider.access(url=url,
                               params=params,
                               method=method,
                               headers=headers,
                               body=body,
                               max_redirects=max_redirects,
                               content_parser=content_parser)

    def async_access(self, *args, **kwargs):
        """
        Same as :meth:`.Authomatic.access` but runs asynchronously in a
        separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.access, *args, **kwargs)

    def request_elements(
            self, credentials=None, url=None, method='GET', params=None,
            headers=None, body='', json_input=None, return_json=False
    ):
        """
        Creates request elements for accessing **protected resource of a
        user**. Required arguments are :data:`credentials` and :data:`url`. You
        can pass :data:`credentials`, :data:`url`, :data:`method`, and
        :data:`params` as a JSON object.

        :param credentials:
            The **user's** credentials (can be serialized).

        :param str url:
            The url of the protected resource.

        :param str method:
            The HTTP method of the request.

        :param dict params:
            Dictionary of request parameters.

        :param dict headers:
            Dictionary of request headers.

        :param str body:
            Body of ``POST``, ``PUT`` and ``PATCH`` requests.

        :param str json_input:
            you can pass :data:`credentials`, :data:`url`, :data:`method`,
            :data:`params` and :data:`headers` in a JSON object.
            Values from arguments will be used for missing properties.

            ::

                {
                    "credentials": "###",
                    "url": "https://example.com/api",
                    "method": "POST",
                    "params": {
                        "foo": "bar"
                    },
                    "headers": {
                        "baz": "bing",
                        "Authorization": "Bearer ###"
                    },
                    "body": "Foo bar baz bing."
                }

        :param bool return_json:
            if ``True`` the function returns a json object.

            ::

                {
                    "url": "https://example.com/api",
                    "method": "POST",
                    "params": {
                        "access_token": "###",
                        "foo": "bar"
                    },
                    "headers": {
                        "baz": "bing",
                        "Authorization": "Bearer ###"
                    },
                    "body": "Foo bar baz bing."
                }

        :raises RequestElementsError:
            If credentials or URL are missing after merging the keyword
            arguments with :data:`json_input`.

        :returns:
            :class:`.RequestElements` or JSON string.

        """

        # Parse values from JSON
        if json_input:
            parsed_input = json.loads(json_input)

            credentials = parsed_input.get('credentials', credentials)
            url = parsed_input.get('url', url)
            method = parsed_input.get('method', method)
            params = parsed_input.get('params', params)
            headers = parsed_input.get('headers', headers)
            body = parsed_input.get('body', body)

        # Both values are mandatory. ``not credentials and url`` would bind
        # as ``(not credentials) and url`` and let a missing URL through.
        if not (credentials and url):
            raise RequestElementsError(
                'To create request elements, you must provide credentials '
                'and URL either as keyword arguments or in the JSON object!')

        # Get the provider class
        credentials = Credentials.deserialize(self.config, credentials)
        ProviderClass = credentials.provider_class

        # Create request elements
        request_elements = ProviderClass.create_request_elements(
            ProviderClass.PROTECTED_RESOURCE_REQUEST_TYPE,
            credentials=credentials,
            url=url,
            method=method,
            params=params,
            headers=headers,
            body=body)

        if return_json:
            return request_elements.to_json()

        return request_elements

    def backend(self, adapter):
        """
        Converts a *request handler* to a JSON backend which you can use with
        :ref:`authomatic.js <js>`.

        Just call it inside a *request handler* like this:

        ::

            class JSONHandler(webapp2.RequestHandler):
                def get(self):
                    authomatic.backend(Webapp2Adapter(self))

        :param adapter:
            The only argument is an :doc:`adapter <adapters>`.

        The *request handler* will now accept these request parameters:

        :param str type:
            Type of the request. Either ``auto``, ``fetch`` or ``elements``.
            Default is ``auto``.

        :param str credentials:
            Serialized :class:`.Credentials`.

        :param str url:
            URL of the **protected resource** request.

        :param str method:
            HTTP method of the **protected resource** request.

        :param str body:
            HTTP body of the **protected resource** request.

        :param JSON params:
            HTTP params of the **protected resource** request as a JSON object.

        :param JSON headers:
            HTTP headers of the **protected resource** request as a
            JSON object.

        :param JSON json:
            You can pass all of the aforementioned params except ``type``
            in a JSON object.

            .. code-block:: javascript

                {
                    "credentials": "######",
                    "url": "https://example.com",
                    "method": "POST",
                    "params": {"foo": "bar"},
                    "headers": {"baz": "bing"},
                    "body": "the body of the request"
                }

        Depending on the ``type`` param, the handler will either write
        a JSON object with *request elements* to the response,
        and add an ``Authomatic-Response-To: elements`` response header, ...

        .. code-block:: javascript

            {
                "url": "https://example.com/api",
                "method": "POST",
                "params": {
                    "access_token": "###",
                    "foo": "bar"
                },
                "headers": {
                    "baz": "bing",
                    "Authorization": "Bearer ###"
                }
            }

        ... or make a fetch to the **protected resource** and forward
        its response content, status and headers with an additional
        ``Authomatic-Response-To: fetch`` header to the response.

        .. warning::

            The backend will not work if you write anything to the
            response in the handler!

        """

        AUTHOMATIC_HEADER = 'Authomatic-Response-To'

        # Collect request params
        request_type = adapter.params.get('type', 'auto')
        json_input = adapter.params.get('json')
        credentials = adapter.params.get('credentials')
        url = adapter.params.get('url')
        method = adapter.params.get('method', 'GET')
        body = adapter.params.get('body', '')

        # ``params`` and ``headers`` arrive as JSON encoded strings.
        params = adapter.params.get('params')
        params = json.loads(params) if params else {}

        headers = adapter.params.get('headers')
        headers = json.loads(headers) if headers else {}

        ProviderClass = Credentials.deserialize(
            self.config, credentials).provider_class

        if request_type == 'auto':
            # If there is a "callback" param, it's a JSONP request.
            jsonp = params.get('callback')

            # JSONP is possible only with GET method. Compare by value:
            # an identity check against the literal is true only when the
            # default above was used, never for a "GET" sent by the client.
            if ProviderClass.supports_jsonp and method == 'GET':
                request_type = 'elements'
            else:
                # Remove the JSONP callback
                if jsonp:
                    params.pop('callback')
                request_type = 'fetch'

        if request_type == 'fetch':
            # Access protected resource
            response = self.access(
                credentials, url, params, method, headers, body)
            result = response.content

            # Forward status
            adapter.status = str(response.status) + ' ' + str(response.reason)

            # Forward headers
            for k, v in response.getheaders():
                self._logger.info('    %s: %s', k, v)
                adapter.set_header(k, v)

        elif request_type == 'elements':
            # Create request elements
            if json_input:
                result = self.request_elements(
                    json_input=json_input, return_json=True)
            else:
                result = self.request_elements(credentials=credentials,
                                               url=url,
                                               method=method,
                                               params=params,
                                               headers=headers,
                                               body=body,
                                               return_json=True)

            adapter.set_header('Content-Type', 'application/json')
        else:
            result = '{"error": "Bad Request!"}'

        # Add the authomatic header
        adapter.set_header(AUTHOMATIC_HEADER, request_type)

        # Write result to response
        adapter.write(result)
@@ -1,1272 +1,1272 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Set of diffing helpers, previously part of vcs
23 Set of diffing helpers, previously part of vcs
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import bz2
28 import bz2
29 import gzip
29 import gzip
30 import time
30 import time
31
31
32 import collections
32 import collections
33 import difflib
33 import difflib
34 import logging
34 import logging
35 import cPickle as pickle
35 import pickle
36 from itertools import tee, imap
36 from itertools import tee
37
37
38 from rhodecode.lib.vcs.exceptions import VCSError
38 from rhodecode.lib.vcs.exceptions import VCSError
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41
41
# module level logger used by the diff helpers below
log = logging.getLogger(__name__)

# Maximum number of context lines; a file with more lines than this is
# unusable in a browser anyway.
MAX_CONTEXT = 20 * 1024
# Number of context lines used when full context is not requested.
DEFAULT_CONTEXT = 3
48
48
49
49
def get_diff_context(request):
    """
    Return the number of context lines to use for the given request.

    :param request: object exposing the query parameters as ``request.GET``
    :returns: ``MAX_CONTEXT`` when the ``fullcontext`` query parameter
        equals ``'1'``, ``DEFAULT_CONTEXT`` otherwise
    """
    wants_full_context = request.GET.get('fullcontext', '') == '1'
    if wants_full_context:
        return MAX_CONTEXT
    return DEFAULT_CONTEXT
52
52
53
53
def get_diff_whitespace_flag(request):
    """
    Tell whether the request asks for whitespace to be ignored.

    :param request: object exposing the query parameters as ``request.GET``
    :returns: ``True`` only when the ``ignorews`` query parameter equals
        ``'1'``
    """
    flag = request.GET.get('ignorews', '')
    return flag == '1'
56
56
57
57
class OPS(object):
    """
    One-letter codes for the operation a diff performs on a file.
    """

    ADD = 'A'  # file was added
    MOD = 'M'  # file was modified
    DEL = 'D'  # file was deleted
62
62
63
63
def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
    """
    Returns git style diff between given ``filenode_old`` and ``filenode_new``.

    :param filenode_old: node holding the old state of the file
    :param filenode_new: node holding the new state of the file
    :param ignore_whitespace: ignore whitespaces in diff
    :param context: number of context lines; a falsy value falls back to 3
        and anything above ``MAX_CONTEXT`` is clamped to ``MAX_CONTEXT``
    :returns: an empty string when either node is a ``SubModuleNode``,
        otherwise the result of ``get_diff`` of the new node's repository
    :raises VCSError: when either node is not a ``FileNode``
    """
    # make sure we pass in default context
    context = context or 3
    # protect against IntOverflow when passing HUGE context
    if context > MAX_CONTEXT:
        context = MAX_CONTEXT

    nodes = (filenode_old, filenode_new)

    # any() rather than a bare filter(): on Python 3 filter() returns a lazy
    # iterator which is always truthy, so testing it would short-circuit
    # every call to the empty result.
    if any(isinstance(node, SubModuleNode) for node in nodes):
        return ''

    for filenode in nodes:
        if not isinstance(filenode, FileNode):
            raise VCSError(
                "Given object should be FileNode object, not %s"
                % filenode.__class__)

    repo = filenode_new.commit.repository
    old_commit = filenode_old.commit or repo.EMPTY_COMMIT
    new_commit = filenode_new.commit

    vcs_gitdiff = repo.get_diff(
        old_commit, new_commit, filenode_new.path,
        ignore_whitespace, context, path1=filenode_old.path)
    return vcs_gitdiff
95
95
# Keys used in the per-file ``stats['ops']`` dictionaries built by the diff
# parsers; each one identifies a kind of change detected for a file.
NEW_FILENODE = 1
DEL_FILENODE = 2
MOD_FILENODE = 3
RENAMED_FILENODE = 4
COPIED_FILENODE = 5
CHMOD_FILENODE = 6
BIN_FILENODE = 7
103
103
104
104
class LimitedDiffContainer(object):
    """
    Wrapper around a parsed diff that also carries the configured diff
    limit and the diff size recorded when the wrapper was created.

    Indexing and iteration are delegated to the wrapped diff.
    """

    def __init__(self, diff_limit, cur_diff_size, diff):
        self.diff_limit = diff_limit
        self.cur_diff_size = cur_diff_size
        self.diff = diff

    def __getitem__(self, key):
        return self.diff[key]

    def __iter__(self):
        for entry in self.diff:
            yield entry
118
118
119
119
class Action(object):
    """
    Contains constants for the action value of the lines in a parsed diff.
    """

    # line level changes
    ADD = 'add'
    DELETE = 'del'
    UNMODIFIED = 'unmod'

    # informational lines and "no newline at end of file" markers
    CONTEXT = 'context'
    OLD_NO_NL = 'old-no-nl'
    NEW_NO_NL = 'new-no-nl'
132
132
133
133
134 class DiffProcessor(object):
134 class DiffProcessor(object):
135 """
135 """
136 Give it a unified or git diff and it returns a list of the files that were
136 Give it a unified or git diff and it returns a list of the files that were
137 mentioned in the diff together with a dict of meta information that
137 mentioned in the diff together with a dict of meta information that
138 can be used to render it in a HTML template.
138 can be used to render it in a HTML template.
139
139
140 .. note:: Unicode handling
140 .. note:: Unicode handling
141
141
142 The original diffs are a byte sequence and can contain filenames
142 The original diffs are a byte sequence and can contain filenames
143 in mixed encodings. This class generally returns `unicode` objects
143 in mixed encodings. This class generally returns `unicode` objects
144 since the result is intended for presentation to the user.
144 since the result is intended for presentation to the user.
145
145
146 """
146 """
147 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
147 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 _newline_marker = re.compile(r'^\\ No newline at end of file')
148 _newline_marker = re.compile(r'^\\ No newline at end of file')
149
149
150 # used for inline highlighter word split
150 # used for inline highlighter word split
151 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
151 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152
152
153 # collapse ranges of commits over given number
153 # collapse ranges of commits over given number
154 _collapse_commits_over = 5
154 _collapse_commits_over = 5
155
155
def __init__(self, diff, format='gitdiff', diff_limit=None,
             file_limit=None, show_full_diff=True):
    """
    :param diff: A `Diff` object representing a diff from a vcs backend
    :param format: format of diff passed, `udiff` or `gitdiff`
    :param diff_limit: define the size of diff that is considered "big"
        based on that parameter cut off will be triggered, set to None
        to show full diff
    :param file_limit: size limit that the raw diff of a single file is
        compared against
    :param show_full_diff: when true the size limits do not cut the diff
    """
    # input and configuration
    self._diff = diff
    self._format = format
    self.diff_limit = diff_limit
    self.file_limit = file_limit
    self.show_full_diff = show_full_diff

    # counters and parsing state
    self.adds = 0
    self.removes = 0
    self.cur_diff_size = 0
    self.parsed = False
    self.parsed_diff = []

    log.debug('Initialized DiffProcessor with %s mode', format)

    # pick the inline highlighter and the parser matching the format
    is_gitdiff = format == 'gitdiff'
    self.differ = (
        self._highlight_line_difflib if is_gitdiff
        else self._highlight_line_udiff)
    self._parser = (
        self._parse_gitdiff if is_gitdiff
        else self._new_parse_gitdiff)
184
184
185 def _copy_iterator(self):
185 def _copy_iterator(self):
186 """
186 """
187 make a fresh copy of generator, we should not iterate thru
187 make a fresh copy of generator, we should not iterate thru
188 an original as it's needed for repeating operations on
188 an original as it's needed for repeating operations on
189 this instance of DiffProcessor
189 this instance of DiffProcessor
190 """
190 """
191 self.__udiff, iterator_copy = tee(self.__udiff)
191 self.__udiff, iterator_copy = tee(self.__udiff)
192 return iterator_copy
192 return iterator_copy
193
193
194 def _escaper(self, string):
194 def _escaper(self, string):
195 """
195 """
196 Escaper for diff escapes special chars and checks the diff limit
196 Escaper for diff escapes special chars and checks the diff limit
197
197
198 :param string:
198 :param string:
199 """
199 """
200 self.cur_diff_size += len(string)
200 self.cur_diff_size += len(string)
201
201
202 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
202 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 raise DiffLimitExceeded('Diff Limit Exceeded')
203 raise DiffLimitExceeded('Diff Limit Exceeded')
204
204
205 return string \
205 return string \
206 .replace('&', '&amp;')\
206 .replace('&', '&amp;')\
207 .replace('<', '&lt;')\
207 .replace('<', '&lt;')\
208 .replace('>', '&gt;')
208 .replace('>', '&gt;')
209
209
def _line_counter(self, l):
    """
    Bump ``self.adds`` / ``self.removes`` according to the given diff line.

    Lines starting with ``+++`` or ``---`` are not counted.

    :param l: a single line of the diff
    :returns: the line passed through ``safe_unicode``
    """
    if not l.startswith(('+++', '---')):
        if l.startswith('+'):
            self.adds += 1
        elif l.startswith('-'):
            self.removes += 1
    return safe_unicode(l)
221
221
def _highlight_line_difflib(self, line, next_):
    """
    Highlight inline changes in both lines.

    Both lines are split into tokens with ``self._token_re`` and matched
    with ``difflib.SequenceMatcher``; differing runs are wrapped in
    ``<del>`` (old line) and ``<ins>`` (new line). The ``'line'`` values
    of both dictionaries are replaced in place.

    :param line: parsed diff line
    :param next_: the parsed diff line following ``line``
    """
    # whichever of the two is the deletion acts as the "old" side
    if line['action'] == Action.DELETE:
        removed, added = line, next_
    else:
        removed, added = next_, line

    removed_tokens = self._token_re.split(removed['line'])
    added_tokens = self._token_re.split(added['line'])
    matcher = difflib.SequenceMatcher(None, removed_tokens, added_tokens)

    removed_parts = []
    added_parts = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        removed_text = ''.join(removed_tokens[i1:i2])
        added_text = ''.join(added_tokens[j1:j2])
        changed = tag != 'equal'
        if changed and removed_text:
            removed_text = '<del>%s</del>' % removed_text
        if changed and added_text:
            added_text = '<ins>%s</ins>' % added_text
        removed_parts.append(removed_text)
        added_parts.append(added_text)

    removed['line'] = "".join(removed_parts)
    added['line'] = "".join(added_parts)
250
250
251 def _highlight_line_udiff(self, line, next_):
251 def _highlight_line_udiff(self, line, next_):
252 """
252 """
253 Highlight inline changes in both lines.
253 Highlight inline changes in both lines.
254 """
254 """
255 start = 0
255 start = 0
256 limit = min(len(line['line']), len(next_['line']))
256 limit = min(len(line['line']), len(next_['line']))
257 while start < limit and line['line'][start] == next_['line'][start]:
257 while start < limit and line['line'][start] == next_['line'][start]:
258 start += 1
258 start += 1
259 end = -1
259 end = -1
260 limit -= start
260 limit -= start
261 while -end <= limit and line['line'][end] == next_['line'][end]:
261 while -end <= limit and line['line'][end] == next_['line'][end]:
262 end -= 1
262 end -= 1
263 end += 1
263 end += 1
264 if start or end:
264 if start or end:
265 def do(l):
265 def do(l):
266 last = end + len(l['line'])
266 last = end + len(l['line'])
267 if l['action'] == Action.ADD:
267 if l['action'] == Action.ADD:
268 tag = 'ins'
268 tag = 'ins'
269 else:
269 else:
270 tag = 'del'
270 tag = 'del'
271 l['line'] = '%s<%s>%s</%s>%s' % (
271 l['line'] = '%s<%s>%s</%s>%s' % (
272 l['line'][:start],
272 l['line'][:start],
273 tag,
273 tag,
274 l['line'][start:last],
274 l['line'][start:last],
275 tag,
275 tag,
276 l['line'][last:]
276 l['line'][last:]
277 )
277 )
278 do(line)
278 do(line)
279 do(next_)
279 do(next_)
280
280
281 def _clean_line(self, line, command):
281 def _clean_line(self, line, command):
282 if command in ['+', '-', ' ']:
282 if command in ['+', '-', ' ']:
283 # only modify the line if it's actually a diff thing
283 # only modify the line if it's actually a diff thing
284 line = line[1:]
284 line = line[1:]
285 return line
285 return line
286
286
def _parse_gitdiff(self, inline_diff=True):
    """
    Parse the chunks of ``self._diff`` into a list of per-file dicts.

    Each dict carries ``filename``, ``old_revision``, ``new_revision``,
    ``chunks``, ``raw_diff``, ``operation``, ``stats``, ``exceeds_limit``
    and ``is_limited_diff``. The list is sorted by operation: added files
    first, then modified, then deleted.

    :param inline_diff: when true, pairs of adjacent changed lines with
        different actions are passed to ``self.differ`` for inline
        highlighting
    :returns: the sorted list, wrapped in a ``LimitedDiffContainer`` when
        a ``DiffLimitExceeded`` was hit for at least one file
    """
    _files = []
    diff_container = lambda arg: arg

    def over_file_limit(size):
        # Python 2 ordered None below every integer, so an unset
        # file_limit always counted as exceeded. Python 3 raises TypeError
        # for that comparison, hence the explicit None check.
        return self.file_limit is None or size > self.file_limit

    for chunk in self._diff.chunks():
        head = chunk.header

        # Lazy replacement for the removed itertools.imap: lines are only
        # escaped as the consumer iterates, so a DiffLimitExceeded raised
        # by _escaper surfaces wherever the lines are consumed.
        diff = (self._escaper(diff_line)
                for diff_line in self.diff_splitter(chunk.diff))
        raw_diff = chunk.raw
        limited_diff = False
        exceeds_limit = False

        op = None
        stats = {
            'added': 0,
            'deleted': 0,
            'binary': False,
            'ops': {},
        }

        if head['deleted_file_mode']:
            op = OPS.DEL
            stats['binary'] = True
            stats['ops'][DEL_FILENODE] = 'deleted file'

        elif head['new_file_mode']:
            op = OPS.ADD
            stats['binary'] = True
            stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
        else:  # modify operation, can be copy, rename or chmod

            # CHMOD
            if head['new_mode'] and head['old_mode']:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][CHMOD_FILENODE] = (
                    'modified file chmod %s => %s' % (
                        head['old_mode'], head['new_mode']))
            # RENAME
            if head['rename_from'] != head['rename_to']:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][RENAMED_FILENODE] = (
                    'file renamed from %s to %s' % (
                        head['rename_from'], head['rename_to']))
            # COPY
            if head.get('copy_from') and head.get('copy_to'):
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][COPIED_FILENODE] = (
                    'file copied from %s to %s' % (
                        head['copy_from'], head['copy_to']))

            # If our new parsed headers didn't match anything fallback to
            # old style detection
            if op is None:
                if not head['a_file'] and head['b_file']:
                    op = OPS.ADD
                    stats['binary'] = True
                    stats['ops'][NEW_FILENODE] = 'new file'

                elif head['a_file'] and not head['b_file']:
                    op = OPS.DEL
                    stats['binary'] = True
                    stats['ops'][DEL_FILENODE] = 'deleted file'

            # it's not ADD not DELETE
            if op is None:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][MOD_FILENODE] = 'modified file'

        # a real non-binary diff
        if head['a_file'] or head['b_file']:
            try:
                raw_diff, chunks, _stats = self._parse_lines(diff)
                stats['binary'] = False
                stats['added'] = _stats[0]
                stats['deleted'] = _stats[1]
                # explicit mark that it's a modified file
                if op == OPS.MOD:
                    stats['ops'][MOD_FILENODE] = 'modified file'
                exceeds_limit = over_file_limit(len(raw_diff))

                # changed from _escaper function so we validate size of
                # each file instead of the whole diff
                # diff will hide big files but still show small ones
                # from my tests, big files are fairly safe to be parsed
                # but the browser is the bottleneck
                if not self.show_full_diff and exceeds_limit:
                    raise DiffLimitExceeded('File Limit Exceeded')

            except DiffLimitExceeded:
                diff_container = lambda _diff: \
                    LimitedDiffContainer(
                        self.diff_limit, self.cur_diff_size, _diff)

                exceeds_limit = over_file_limit(len(raw_diff))
                limited_diff = True
                chunks = []

        else:  # GIT format binary patch, or possibly empty diff
            if head['bin_patch']:
                # we have operation already extracted, but we mark simply
                # it's a diff we wont show for binary files
                stats['ops'][BIN_FILENODE] = 'binary diff hidden'
            chunks = []

        if chunks and not self.show_full_diff and op == OPS.DEL:
            # if not full diff mode show deleted file contents
            # TODO: anderson: if the view is not too big, there is no way
            # to see the content of the file
            chunks = []

        # first chunk lists the detected operations as context lines;
        # dict.items() because iteritems() is gone on Python 3
        chunks.insert(0, [{
            'old_lineno': '',
            'new_lineno': '',
            'action': Action.CONTEXT,
            'line': msg,
        } for _op, msg in stats['ops'].items()
            if _op not in [MOD_FILENODE]])

        _files.append({
            'filename': safe_unicode(head['b_path']),
            'old_revision': head['a_blob_id'],
            'new_revision': head['b_blob_id'],
            'chunks': chunks,
            'raw_diff': safe_unicode(raw_diff),
            'operation': op,
            'stats': stats,
            'exceeds_limit': exceeds_limit,
            'is_limited_diff': limited_diff,
        })

    sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
                           OPS.DEL: 2}.get(info['operation'])

    if not inline_diff:
        return diff_container(sorted(_files, key=sorter))

    # highlight inline changes
    for diff_data in _files:
        for chunk in diff_data['chunks']:
            lineiter = iter(chunk)
            try:
                while 1:
                    line = next(lineiter)
                    if line['action'] not in (
                            Action.UNMODIFIED, Action.CONTEXT):
                        nextline = next(lineiter)
                        if nextline['action'] in ['unmod', 'context'] or \
                                nextline['action'] == line['action']:
                            continue
                        self.differ(line, nextline)
            except StopIteration:
                # ran out of lines in this chunk
                pass

    return diff_container(sorted(_files, key=sorter))
445
445
def _check_large_diff(self):
    """
    Raise when the accumulated diff size is over the configured limit.

    :raises DiffLimitExceeded: when ``show_full_diff`` is false and
        ``cur_diff_size`` is over ``diff_limit``
    """
    # NOTE(review): indentation is not recoverable from the diff view this
    # was taken from; the size check is assumed to sit outside the logging
    # guard - confirm against the repository.
    if self.diff_limit:
        log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)

    if self.show_full_diff:
        return

    # Python 2 ordered None below every integer, so an unset limit always
    # counted as exceeded. Python 3 raises TypeError for that comparison,
    # hence the explicit None check.
    if self.diff_limit is None or self.cur_diff_size > self.diff_limit:
        # interpolate here: unlike logging calls, exception constructors
        # do not apply %-formatting to additional arguments
        raise DiffLimitExceeded(
            'Diff Limit `%s` Exceeded' % self.diff_limit)
451
451
452 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
452 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
453 def _new_parse_gitdiff(self, inline_diff=True):
453 def _new_parse_gitdiff(self, inline_diff=True):
454 _files = []
454 _files = []
455
455
456 # this can be overriden later to a LimitedDiffContainer type
456 # this can be overriden later to a LimitedDiffContainer type
457 diff_container = lambda arg: arg
457 diff_container = lambda arg: arg
458
458
459 for chunk in self._diff.chunks():
459 for chunk in self._diff.chunks():
460 head = chunk.header
460 head = chunk.header
461 log.debug('parsing diff %r', head)
461 log.debug('parsing diff %r', head)
462
462
463 raw_diff = chunk.raw
463 raw_diff = chunk.raw
464 limited_diff = False
464 limited_diff = False
465 exceeds_limit = False
465 exceeds_limit = False
466
466
467 op = None
467 op = None
468 stats = {
468 stats = {
469 'added': 0,
469 'added': 0,
470 'deleted': 0,
470 'deleted': 0,
471 'binary': False,
471 'binary': False,
472 'old_mode': None,
472 'old_mode': None,
473 'new_mode': None,
473 'new_mode': None,
474 'ops': {},
474 'ops': {},
475 }
475 }
476 if head['old_mode']:
476 if head['old_mode']:
477 stats['old_mode'] = head['old_mode']
477 stats['old_mode'] = head['old_mode']
478 if head['new_mode']:
478 if head['new_mode']:
479 stats['new_mode'] = head['new_mode']
479 stats['new_mode'] = head['new_mode']
480 if head['b_mode']:
480 if head['b_mode']:
481 stats['new_mode'] = head['b_mode']
481 stats['new_mode'] = head['b_mode']
482
482
483 # delete file
483 # delete file
484 if head['deleted_file_mode']:
484 if head['deleted_file_mode']:
485 op = OPS.DEL
485 op = OPS.DEL
486 stats['binary'] = True
486 stats['binary'] = True
487 stats['ops'][DEL_FILENODE] = 'deleted file'
487 stats['ops'][DEL_FILENODE] = 'deleted file'
488
488
489 # new file
489 # new file
490 elif head['new_file_mode']:
490 elif head['new_file_mode']:
491 op = OPS.ADD
491 op = OPS.ADD
492 stats['binary'] = True
492 stats['binary'] = True
493 stats['old_mode'] = None
493 stats['old_mode'] = None
494 stats['new_mode'] = head['new_file_mode']
494 stats['new_mode'] = head['new_file_mode']
495 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
495 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
496
496
497 # modify operation, can be copy, rename or chmod
497 # modify operation, can be copy, rename or chmod
498 else:
498 else:
499 # CHMOD
499 # CHMOD
500 if head['new_mode'] and head['old_mode']:
500 if head['new_mode'] and head['old_mode']:
501 op = OPS.MOD
501 op = OPS.MOD
502 stats['binary'] = True
502 stats['binary'] = True
503 stats['ops'][CHMOD_FILENODE] = (
503 stats['ops'][CHMOD_FILENODE] = (
504 'modified file chmod %s => %s' % (
504 'modified file chmod %s => %s' % (
505 head['old_mode'], head['new_mode']))
505 head['old_mode'], head['new_mode']))
506
506
507 # RENAME
507 # RENAME
508 if head['rename_from'] != head['rename_to']:
508 if head['rename_from'] != head['rename_to']:
509 op = OPS.MOD
509 op = OPS.MOD
510 stats['binary'] = True
510 stats['binary'] = True
511 stats['renamed'] = (head['rename_from'], head['rename_to'])
511 stats['renamed'] = (head['rename_from'], head['rename_to'])
512 stats['ops'][RENAMED_FILENODE] = (
512 stats['ops'][RENAMED_FILENODE] = (
513 'file renamed from %s to %s' % (
513 'file renamed from %s to %s' % (
514 head['rename_from'], head['rename_to']))
514 head['rename_from'], head['rename_to']))
515 # COPY
515 # COPY
516 if head.get('copy_from') and head.get('copy_to'):
516 if head.get('copy_from') and head.get('copy_to'):
517 op = OPS.MOD
517 op = OPS.MOD
518 stats['binary'] = True
518 stats['binary'] = True
519 stats['copied'] = (head['copy_from'], head['copy_to'])
519 stats['copied'] = (head['copy_from'], head['copy_to'])
520 stats['ops'][COPIED_FILENODE] = (
520 stats['ops'][COPIED_FILENODE] = (
521 'file copied from %s to %s' % (
521 'file copied from %s to %s' % (
522 head['copy_from'], head['copy_to']))
522 head['copy_from'], head['copy_to']))
523
523
524 # If our new parsed headers didn't match anything fallback to
524 # If our new parsed headers didn't match anything fallback to
525 # old style detection
525 # old style detection
526 if op is None:
526 if op is None:
527 if not head['a_file'] and head['b_file']:
527 if not head['a_file'] and head['b_file']:
528 op = OPS.ADD
528 op = OPS.ADD
529 stats['binary'] = True
529 stats['binary'] = True
530 stats['new_file'] = True
530 stats['new_file'] = True
531 stats['ops'][NEW_FILENODE] = 'new file'
531 stats['ops'][NEW_FILENODE] = 'new file'
532
532
533 elif head['a_file'] and not head['b_file']:
533 elif head['a_file'] and not head['b_file']:
534 op = OPS.DEL
534 op = OPS.DEL
535 stats['binary'] = True
535 stats['binary'] = True
536 stats['ops'][DEL_FILENODE] = 'deleted file'
536 stats['ops'][DEL_FILENODE] = 'deleted file'
537
537
538 # it's not ADD not DELETE
538 # it's not ADD not DELETE
539 if op is None:
539 if op is None:
540 op = OPS.MOD
540 op = OPS.MOD
541 stats['binary'] = True
541 stats['binary'] = True
542 stats['ops'][MOD_FILENODE] = 'modified file'
542 stats['ops'][MOD_FILENODE] = 'modified file'
543
543
544 # a real non-binary diff
544 # a real non-binary diff
545 if head['a_file'] or head['b_file']:
545 if head['a_file'] or head['b_file']:
546 # simulate splitlines, so we keep the line end part
546 # simulate splitlines, so we keep the line end part
547 diff = self.diff_splitter(chunk.diff)
547 diff = self.diff_splitter(chunk.diff)
548
548
549 # append each file to the diff size
549 # append each file to the diff size
550 raw_chunk_size = len(raw_diff)
550 raw_chunk_size = len(raw_diff)
551
551
552 exceeds_limit = raw_chunk_size > self.file_limit
552 exceeds_limit = raw_chunk_size > self.file_limit
553 self.cur_diff_size += raw_chunk_size
553 self.cur_diff_size += raw_chunk_size
554
554
555 try:
555 try:
556 # Check each file instead of the whole diff.
556 # Check each file instead of the whole diff.
557 # Diff will hide big files but still show small ones.
557 # Diff will hide big files but still show small ones.
558 # From the tests big files are fairly safe to be parsed
558 # From the tests big files are fairly safe to be parsed
559 # but the browser is the bottleneck.
559 # but the browser is the bottleneck.
560 if not self.show_full_diff and exceeds_limit:
560 if not self.show_full_diff and exceeds_limit:
561 log.debug('File `%s` exceeds current file_limit of %s',
561 log.debug('File `%s` exceeds current file_limit of %s',
562 safe_unicode(head['b_path']), self.file_limit)
562 safe_unicode(head['b_path']), self.file_limit)
563 raise DiffLimitExceeded(
563 raise DiffLimitExceeded(
564 'File Limit %s Exceeded', self.file_limit)
564 'File Limit %s Exceeded', self.file_limit)
565
565
566 self._check_large_diff()
566 self._check_large_diff()
567
567
568 raw_diff, chunks, _stats = self._new_parse_lines(diff)
568 raw_diff, chunks, _stats = self._new_parse_lines(diff)
569 stats['binary'] = False
569 stats['binary'] = False
570 stats['added'] = _stats[0]
570 stats['added'] = _stats[0]
571 stats['deleted'] = _stats[1]
571 stats['deleted'] = _stats[1]
572 # explicit mark that it's a modified file
572 # explicit mark that it's a modified file
573 if op == OPS.MOD:
573 if op == OPS.MOD:
574 stats['ops'][MOD_FILENODE] = 'modified file'
574 stats['ops'][MOD_FILENODE] = 'modified file'
575
575
576 except DiffLimitExceeded:
576 except DiffLimitExceeded:
577 diff_container = lambda _diff: \
577 diff_container = lambda _diff: \
578 LimitedDiffContainer(
578 LimitedDiffContainer(
579 self.diff_limit, self.cur_diff_size, _diff)
579 self.diff_limit, self.cur_diff_size, _diff)
580
580
581 limited_diff = True
581 limited_diff = True
582 chunks = []
582 chunks = []
583
583
584 else: # GIT format binary patch, or possibly empty diff
584 else: # GIT format binary patch, or possibly empty diff
585 if head['bin_patch']:
585 if head['bin_patch']:
586 # we have operation already extracted, but we mark simply
586 # we have operation already extracted, but we mark simply
587 # it's a diff we wont show for binary files
587 # it's a diff we wont show for binary files
588 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
588 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
589 chunks = []
589 chunks = []
590
590
591 # Hide content of deleted node by setting empty chunks
591 # Hide content of deleted node by setting empty chunks
592 if chunks and not self.show_full_diff and op == OPS.DEL:
592 if chunks and not self.show_full_diff and op == OPS.DEL:
593 # if not full diff mode show deleted file contents
593 # if not full diff mode show deleted file contents
594 # TODO: anderson: if the view is not too big, there is no way
594 # TODO: anderson: if the view is not too big, there is no way
595 # to see the content of the file
595 # to see the content of the file
596 chunks = []
596 chunks = []
597
597
598 chunks.insert(
598 chunks.insert(
599 0, [{'old_lineno': '',
599 0, [{'old_lineno': '',
600 'new_lineno': '',
600 'new_lineno': '',
601 'action': Action.CONTEXT,
601 'action': Action.CONTEXT,
602 'line': msg,
602 'line': msg,
603 } for _op, msg in stats['ops'].iteritems()
603 } for _op, msg in stats['ops'].iteritems()
604 if _op not in [MOD_FILENODE]])
604 if _op not in [MOD_FILENODE]])
605
605
606 original_filename = safe_unicode(head['a_path'])
606 original_filename = safe_unicode(head['a_path'])
607 _files.append({
607 _files.append({
608 'original_filename': original_filename,
608 'original_filename': original_filename,
609 'filename': safe_unicode(head['b_path']),
609 'filename': safe_unicode(head['b_path']),
610 'old_revision': head['a_blob_id'],
610 'old_revision': head['a_blob_id'],
611 'new_revision': head['b_blob_id'],
611 'new_revision': head['b_blob_id'],
612 'chunks': chunks,
612 'chunks': chunks,
613 'raw_diff': safe_unicode(raw_diff),
613 'raw_diff': safe_unicode(raw_diff),
614 'operation': op,
614 'operation': op,
615 'stats': stats,
615 'stats': stats,
616 'exceeds_limit': exceeds_limit,
616 'exceeds_limit': exceeds_limit,
617 'is_limited_diff': limited_diff,
617 'is_limited_diff': limited_diff,
618 })
618 })
619
619
620 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
620 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
621 OPS.DEL: 2}.get(info['operation'])
621 OPS.DEL: 2}.get(info['operation'])
622
622
623 return diff_container(sorted(_files, key=sorter))
623 return diff_container(sorted(_files, key=sorter))
624
624
625 # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
625 # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
def _parse_lines(self, diff_iter):
    """
    Parse the hunks of a single file diff and return data for the template.

    Legacy parser; per the FIXME on this method it is meant to be replaced
    by ``_new_parse_lines``.

    :param diff_iter: iterator over the lines of one file diff, positioned
        at the first ``@@`` hunk header.
    :return: tuple ``(raw_diff, chunks, stats)`` where ``raw_diff`` is the
        consumed lines joined into one string, ``chunks`` is a list with
        one list of line dicts (``old_lineno``, ``new_lineno``, ``action``,
        ``line``) per hunk, and ``stats`` is ``[added, deleted]``.
    """

    stats = [0, 0]
    chunks = []
    raw_diff = []

    try:
        line = next(diff_iter)

        while line:
            raw_diff.append(line)
            lines = []
            chunks.append(lines)

            match = self._chunk_re.match(line)

            if not match:
                break

            gr = match.groups()
            (old_line, old_end,
             new_line, new_end) = [int(x or 1) for x in gr[:-1]]
            # counters are pre-incremented below, so start one line early
            old_line -= 1
            new_line -= 1

            context = len(gr) == 5
            # header carries lengths; turn them into absolute end positions
            old_end += old_line
            new_end += new_line

            if context:
                # skip context only if it's first line
                if int(gr[0]) > 1:
                    lines.append({
                        'old_lineno': '...',
                        'new_lineno': '...',
                        'action': Action.CONTEXT,
                        'line': line,
                    })

            line = next(diff_iter)

            while old_line < old_end or new_line < new_end:
                command = ' '
                if line:
                    command = line[0]

                affects_old = affects_new = False

                # ignore those if we don't expect them
                if command in '#@':
                    # The line has to be consumed here: a bare ``continue``
                    # re-evaluates the loop condition with the very same
                    # line and never terminates.
                    line = next(diff_iter)
                    continue
                elif command == '+':
                    affects_new = True
                    action = Action.ADD
                    stats[0] += 1
                elif command == '-':
                    affects_old = True
                    action = Action.DELETE
                    stats[1] += 1
                else:
                    affects_old = affects_new = True
                    action = Action.UNMODIFIED

                if not self._newline_marker.match(line):
                    old_line += affects_old
                    new_line += affects_new
                    lines.append({
                        'old_lineno': (old_line or '') if affects_old else '',
                        'new_lineno': (new_line or '') if affects_new else '',
                        'action': action,
                        'line': self._clean_line(line, command)
                    })
                    raw_diff.append(line)

                line = next(diff_iter)

                if self._newline_marker.match(line):
                    # we need to append to lines, since this is not
                    # counted in the line specs of diff
                    lines.append({
                        'old_lineno': '...',
                        'new_lineno': '...',
                        'action': Action.CONTEXT,
                        'line': self._clean_line(line, command)
                    })

    except StopIteration:
        # running out of input simply ends the parse
        pass
    return ''.join(raw_diff), chunks, stats
718
718
719 # FIXME: NEWDIFFS: dan: this replaces _parse_lines
719 # FIXME: NEWDIFFS: dan: this replaces _parse_lines
def _new_parse_lines(self, diff_iter):
    """
    Parse the hunks of a single file diff and return data for the template.

    :param diff_iter: iterator over the lines of one file diff, positioned
        at the first ``@@`` hunk header.
    :return: tuple ``(raw_diff, chunks, stats)`` where ``raw_diff`` is the
        consumed lines joined into one string, ``chunks`` is a list of hunk
        dicts (``section_header``, ``source_start``, ``source_length``,
        ``target_start``, ``target_length``, ``lines``) and ``stats`` is
        ``[added, deleted]``.
    :raises Exception: when a "no newline" marker follows a line that
        affects neither the old nor the new side.
    """

    stats = [0, 0]
    chunks = []
    raw_diff = []

    try:
        line = next(diff_iter)

        while line:
            raw_diff.append(line)
            # match header e.g @@ -0,0 +1 @@\n'
            match = self._chunk_re.match(line)

            if not match:
                break

            gr = match.groups()
            (old_line, old_end,
             new_line, new_end) = [int(x or 1) for x in gr[:-1]]

            lines = []
            hunk = {
                'section_header': gr[-1],
                'source_start': old_line,
                'source_length': old_end,
                'target_start': new_line,
                'target_length': new_end,
                'lines': lines,
            }
            chunks.append(hunk)

            # counters are pre-incremented below, so start one line early
            old_line -= 1
            new_line -= 1

            # header carries lengths; turn them into absolute end positions
            old_end += old_line
            new_end += new_line

            line = next(diff_iter)

            while old_line < old_end or new_line < new_end:
                command = ' '
                if line:
                    command = line[0]

                affects_old = affects_new = False

                # ignore those if we don't expect them
                if command in '#@':
                    # The line has to be consumed here: a bare ``continue``
                    # re-evaluates the loop condition with the very same
                    # line and never terminates.
                    line = next(diff_iter)
                    continue
                elif command == '+':
                    affects_new = True
                    action = Action.ADD
                    stats[0] += 1
                elif command == '-':
                    affects_old = True
                    action = Action.DELETE
                    stats[1] += 1
                else:
                    affects_old = affects_new = True
                    action = Action.UNMODIFIED

                if not self._newline_marker.match(line):
                    old_line += affects_old
                    new_line += affects_new
                    lines.append({
                        'old_lineno': (old_line or '') if affects_old else '',
                        'new_lineno': (new_line or '') if affects_new else '',
                        'action': action,
                        'line': self._clean_line(line, command)
                    })
                    raw_diff.append(line)

                line = next(diff_iter)

                if self._newline_marker.match(line):
                    # we need to append to lines, since this is not
                    # counted in the line specs of diff
                    if affects_old:
                        action = Action.OLD_NO_NL
                    elif affects_new:
                        action = Action.NEW_NO_NL
                    else:
                        raise Exception('invalid context for no newline')

                    lines.append({
                        'old_lineno': None,
                        'new_lineno': None,
                        'action': action,
                        'line': self._clean_line(line, command)
                    })

    except StopIteration:
        # running out of input simply ends the parse
        pass

    return ''.join(raw_diff), chunks, stats
820
820
821 def _safe_id(self, idstring):
821 def _safe_id(self, idstring):
822 """Make a string safe for including in an id attribute.
822 """Make a string safe for including in an id attribute.
823
823
824 The HTML spec says that id attributes 'must begin with
824 The HTML spec says that id attributes 'must begin with
825 a letter ([A-Za-z]) and may be followed by any number
825 a letter ([A-Za-z]) and may be followed by any number
826 of letters, digits ([0-9]), hyphens ("-"), underscores
826 of letters, digits ([0-9]), hyphens ("-"), underscores
827 ("_"), colons (":"), and periods (".")'. These regexps
827 ("_"), colons (":"), and periods (".")'. These regexps
828 are slightly over-zealous, in that they remove colons
828 are slightly over-zealous, in that they remove colons
829 and periods unnecessarily.
829 and periods unnecessarily.
830
830
831 Whitespace is transformed into underscores, and then
831 Whitespace is transformed into underscores, and then
832 anything which is not a hyphen or a character that
832 anything which is not a hyphen or a character that
833 matches \w (alphanumerics and underscore) is removed.
833 matches \w (alphanumerics and underscore) is removed.
834
834
835 """
835 """
836 # Transform all whitespace to underscore
836 # Transform all whitespace to underscore
837 idstring = re.sub(r'\s', "_", '%s' % idstring)
837 idstring = re.sub(r'\s', "_", '%s' % idstring)
838 # Remove everything that is not a hyphen or a member of \w
838 # Remove everything that is not a hyphen or a member of \w
839 idstring = re.sub(r'(?!-)\W', "", idstring).lower()
839 idstring = re.sub(r'(?!-)\W', "", idstring).lower()
840 return idstring
840 return idstring
841
841
@classmethod
def diff_splitter(cls, string):
    """
    Yield the lines of ``string`` split on ``\\n`` only, keeping line ends.

    Behaves like ``.splitlines(True)`` restricted to ``\\n``: every yielded
    line keeps its trailing newline, except a final line that had none.
    Nothing is yielded for an empty input.
    """
    if not string:
        return

    if string == '\n':
        yield u'\n'
        return

    ends_with_newline = string.endswith('\n')
    parts = string.split('\n')
    if ends_with_newline:
        # the split leaves an empty trailing element after the last newline
        parts = parts[:-1]

    final_idx = len(parts) - 1
    for idx, part in enumerate(parts):
        if idx == final_idx and not ends_with_newline:
            yield safe_unicode(part)
        else:
            yield safe_unicode(part) + '\n'
867
867
def prepare(self, inline_diff=True):
    """
    Run the configured parser over the diff and cache its result.

    Sets ``self.parsed`` to ``True`` and stores the parser output on
    ``self.parsed_diff``.

    :param inline_diff: forwarded to the parser as a keyword argument.
    :return: the parser output (a list of dicts with diff information).
    """
    result = self._parser(inline_diff=inline_diff)
    self.parsed, self.parsed_diff = True, result
    return result
878
878
def as_raw(self, diff_lines=None):
    """
    Return the ``raw`` attribute of the wrapped diff object.

    :param diff_lines: accepted but not used by this method.
    """
    raw = self._diff.raw
    return raw
884
884
def as_html(self, table_class='code-difftable', line_class='line',
            old_lineno_class='lineno old', new_lineno_class='lineno new',
            code_class='code', enable_comments=False, parsed_lines=None):
    """
    Render the diff as an HTML table using the given CSS class names.

    Parses the diff first if that has not happened yet. When
    ``parsed_lines`` is truthy it is rendered instead of the cached parse
    result.

    :return: the HTML as one string, or ``None`` when no chunk was seen.
    """
    # TODO(marcink): not sure how to pass in translator
    # here in an efficient way, leave the _ for proper gettext extraction
    _ = lambda s: s

    def _link_to_if(condition, label, url):
        """
        Return a link to ``url`` if ``condition`` is met, else ``label``.
        """
        if not condition:
            return label
        return '''<a href="%(url)s" class="tooltip"
title="%(title)s">%(label)s</a>''' % {
            'title': _('Click to select line'),
            'url': url,
            'label': label
        }

    if not self.parsed:
        self.prepare()

    diff_lines = self.parsed_diff
    if parsed_lines:
        diff_lines = parsed_lines

    saw_chunk = False
    _html = [
        '''<table class="%(table_class)s">\n''' % {
            'table_class': table_class
        }
    ]
    emit = _html.append

    for diff in diff_lines:
        for chunk in diff['chunks']:
            saw_chunk = True
            for change in chunk:
                action = change['action']
                old_no = change['old_lineno']
                new_no = change['new_lineno']

                emit('''<tr class="%(lc)s %(action)s">\n''' % {
                    'lc': line_class,
                    'action': action
                })

                anchor_old = "%(filename)s_o%(oldline_no)s" % {
                    'filename': self._safe_id(diff['filename']),
                    'oldline_no': old_no
                }
                anchor_new = "%(filename)s_n%(oldline_no)s" % {
                    'filename': self._safe_id(diff['filename']),
                    'oldline_no': new_no
                }
                # only real line numbers get an id attribute
                anchor_old_id = ''
                if old_no != '...' and old_no:
                    anchor_old_id = 'id="%s"' % anchor_old
                anchor_new_id = ''
                if new_no != '...' and new_no:
                    anchor_new_id = 'id="%s"' % anchor_new

                anchor_link = bool(action != Action.CONTEXT)

                ###########################################################
                # COMMENT ICONS
                ###########################################################
                emit('''\t<td class="add-comment-line"><span class="add-comment-content">''')

                if enable_comments and action != Action.CONTEXT:
                    emit('''<a href="#"><span class="icon-comment-add"></span></a>''')

                emit('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')

                ###########################################################
                # OLD LINE NUMBER
                ###########################################################
                emit('''\t<td %(a_id)s class="%(olc)s">''' % {
                    'a_id': anchor_old_id,
                    'olc': old_lineno_class
                })
                emit('''%(link)s''' % {
                    'link': _link_to_if(anchor_link, old_no,
                                        '#%s' % anchor_old)
                })
                emit('''</td>\n''')

                ###########################################################
                # NEW LINE NUMBER
                ###########################################################
                emit('''\t<td %(a_id)s class="%(nlc)s">''' % {
                    'a_id': anchor_new_id,
                    'nlc': new_lineno_class
                })
                emit('''%(link)s''' % {
                    'link': _link_to_if(anchor_link, new_no,
                                        '#%s' % anchor_new)
                })
                emit('''</td>\n''')

                ###########################################################
                # CODE
                ###########################################################
                code_classes = [code_class]
                if not enable_comments or action == Action.CONTEXT:
                    code_classes.append('no-comment')
                emit('\t<td class="%s">' % ' '.join(code_classes))
                emit('''\n\t\t<pre>%(code)s</pre>\n''' % {
                    'code': change['line']
                })

                emit('''\t</td>''')
                emit('''\n</tr>\n''')

    emit('''</table>''')
    if not saw_chunk:
        return None
    return ''.join(_html)
1009
1009
def stat(self):
    """
    Return the ``(adds, removes)`` pair stored on this instance.
    """
    added, removed = self.adds, self.removes
    return added, removed
1015
1015
def get_context_of_line(
        self, path, diff_line=None, context_before=3, context_after=3):
    """
    Return the diff content lines surrounding the given diff line.

    :param path: filename of the file diff to look in.
    :type diff_line: :class:`DiffLineNumber`
    :param context_before: number of chunk lines to include before the hit.
    :param context_after: number of chunk lines to include after the hit.
    :return: list of ``(action, line)`` tuples; the last entry is
        normalized to end with exactly one newline.
    :raises ValueError: when ``diff_line`` contains no ``None`` entry.
    """
    assert self.parsed, "DiffProcessor is not initialized."

    if None not in diff_line:
        raise ValueError(
            "Cannot specify both line numbers: {}".format(diff_line))

    file_diff = self._get_file_diff(path)
    chunk, idx = self._find_chunk_line_index(file_diff, diff_line)

    window_start = max(idx - context_before, 0)
    window_stop = idx + context_after + 1

    line_contents = []
    for entry in chunk[window_start:window_stop]:
        if _is_diff_content(entry):
            line_contents.append(_context_line(entry))

    # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
    # Once they are fixed, we can drop this line here.
    if line_contents:
        last = line_contents[-1]
        line_contents[-1] = (last[0], last[1].rstrip('\n') + '\n')
    return line_contents
1045
1045
def find_context(self, path, context, offset=0):
    """
    Locate the given `context` inside the diff of the file at `path`.

    `offset` names the position of the line of interest within `context`,
    so that the diff line number of that particular line is returned for
    every match.

    :param offset: position of the main line inside `context`; must lie
        between zero and ``len(context) - 1``.
    :return: list of :class:`DiffLineNumber`, one per match.
    :raises ValueError: when `offset` is outside the allowed range.
    """
    if not 0 <= offset < len(context):
        raise ValueError(
            "Only positive values up to the length of the context "
            "minus one are allowed.")

    hits = []
    file_diff = self._get_file_diff(path)

    for chunk in file_diff['chunks']:
        expected = iter(context)
        for line_idx, line in enumerate(chunk):
            try:
                if _context_line(line) == next(expected):
                    continue
            except StopIteration:
                # the whole context was consumed: record the hit and
                # start over with a fresh iterator
                hits.append((line_idx, chunk))
                expected = iter(context)

        # Increment position and trigger StopIteration
        # if we had a match at the end
        line_idx += 1
        try:
            next(expected)
        except StopIteration:
            hits.append((line_idx, chunk))

    effective_offset = len(context) - offset
    found_at_diff_lines = []
    for hit_idx, hit_chunk in hits:
        target = hit_chunk[hit_idx - effective_offset]
        found_at_diff_lines.append(_line_to_diff_line_number(target))

    return found_at_diff_lines
1089
1089
1090 def _get_file_diff(self, path):
1090 def _get_file_diff(self, path):
1091 for file_diff in self.parsed_diff:
1091 for file_diff in self.parsed_diff:
1092 if file_diff['filename'] == path:
1092 if file_diff['filename'] == path:
1093 break
1093 break
1094 else:
1094 else:
1095 raise FileNotInDiffException("File {} not in diff".format(path))
1095 raise FileNotInDiffException("File {} not in diff".format(path))
1096 return file_diff
1096 return file_diff
1097
1097
def _find_chunk_line_index(self, file_diff, diff_line):
    """
    Locate `diff_line` inside the chunks of `file_diff`.

    A line matches when its ``'old_lineno'`` equals ``diff_line.old`` or
    its ``'new_lineno'`` equals ``diff_line.new``; the old number is
    compared first.

    :param file_diff: mapping with a ``'chunks'`` list of line lists.
    :param diff_line: object exposing ``old`` and ``new`` attributes.
    :returns: tuple ``(chunk, index_of_line_in_chunk)`` of the first match.
    :raises LineNotInDiffException: when no line matches.
    """
    for chunk in file_diff['chunks']:
        for position, entry in enumerate(chunk):
            matches = (
                entry['old_lineno'] == diff_line.old
                or entry['new_lineno'] == diff_line.new)
            if matches:
                return chunk, position
    raise LineNotInDiffException(
        "The line {} is not part of the diff.".format(diff_line))
1107
1107
1108
1108
def _is_diff_content(line):
    """
    Tell whether `line` is an unmodified, added or deleted diff line.

    :param line: mapping with an ``'action'`` key.
    :returns: True when the action is one of ``Action.UNMODIFIED``,
        ``Action.ADD`` or ``Action.DELETE``.
    """
    action = line['action']
    content_actions = (Action.UNMODIFIED, Action.ADD, Action.DELETE)
    return action in content_actions
1112
1112
1113
1113
def _context_line(line):
    """
    Reduce a diff line mapping to the pair ``(action, line text)``.

    :param line: mapping with ``'action'`` and ``'line'`` keys.
    """
    action = line['action']
    text = line['line']
    return action, text
1116
1116
1117
1117
# Pair of line numbers: position in the old file and in the new file.
DiffLineNumber = collections.namedtuple('DiffLineNumber', 'old new')


def _line_to_diff_line_number(line):
    """
    Build a ``DiffLineNumber`` from a diff line mapping.

    Falsy line numbers (for example an empty string or ``0``) are
    normalised to ``None``.

    :param line: mapping with ``'new_lineno'`` and ``'old_lineno'`` keys.
    """
    # 'new_lineno' is read before 'old_lineno', as keyword arguments are
    # evaluated left to right.
    return DiffLineNumber(
        new=line['new_lineno'] or None,
        old=line['old_lineno'] or None)
1125
1125
1126
1126
class FileNotInDiffException(Exception):
    """
    Signals that the requested file is not part of the diff.

    Raised when context is requested for a line of a file, but the given
    diff contains no entry for that file.
    """
1134
1134
1135
1135
class LineNotInDiffException(Exception):
    """
    Signals that the requested line is not part of the diff.

    Raised when context is requested for a line of a file, but that line
    does not occur in the given diff.
    """
1143
1143
1144
1144
class DiffLimitExceeded(Exception):
    """
    Signals that a diff limit was exceeded.

    NOTE(review): the raise sites are outside this part of the file;
    confirm there which limit is meant.
    """
1147
1147
1148
1148
1149 # NOTE(marcink): if diffs.mako changes, this probably
1149 # NOTE(marcink): if diffs.mako changes, this probably
1150 # needs a bump to the next version
1150 # needs a bump to the next version
1151 CURRENT_DIFF_VERSION = 'v5'
1151 CURRENT_DIFF_VERSION = 'v5'
1152
1152
1153
1153
def _cleanup_cache_file(cached_diff_file):
    """
    Best-effort removal of a cache file, so a damaged copy is not kept.

    Any error raised by the removal is logged with its traceback and
    swallowed; nothing is returned.

    :param cached_diff_file: path of the cache file to delete.
    """
    try:
        os.remove(cached_diff_file)
    except Exception:
        log.exception('Failed to cleanup path %s', cached_diff_file)
1160
1160
1161
1161
def _get_compression_mode(cached_diff_file):
    """
    Derive the compression mode from markers embedded in the cache path.

    :param cached_diff_file: cache file path (or key) to inspect.
    :returns: ``'plain'`` if the path contains ``mode:plain``, else
        ``'gzip'`` if it contains ``mode:gzip``, otherwise ``'bz2'``.
    """
    # 'mode:plain' is checked first, so it wins if both markers are present
    if 'mode:plain' in cached_diff_file:
        return 'plain'
    if 'mode:gzip' in cached_diff_file:
        return 'gzip'
    return 'bz2'
1169
1169
1170
1170
def cache_diff(cached_diff_file, diff, commits):
    """
    Pickle `diff` and `commits` into `cached_diff_file`.

    The payload is a dict with the keys ``'version'`` (set to
    ``CURRENT_DIFF_VERSION``), ``'diff'`` and ``'commits'``. The container
    format (plain file, gzip or bz2) is chosen by
    ``_get_compression_mode`` from the file path.

    Saving is best-effort: on any error a warning is logged and the
    possibly partial file is removed. The success message is only logged
    when the dump actually succeeded.

    :param cached_diff_file: destination path of the cache file.
    :param diff: object stored under ``'diff'``; must be picklable.
    :param commits: object stored under ``'commits'``; must be picklable.
    :returns: None
    """
    compression_mode = _get_compression_mode(cached_diff_file)

    struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': diff,
        'commits': commits
    }

    # all three openers accept (path, mode) and work as context managers
    if compression_mode == 'plain':
        opener = open
    elif compression_mode == 'gzip':
        opener = gzip.GzipFile
    else:
        opener = bz2.BZ2File

    start = time.time()
    try:
        with opener(cached_diff_file, 'wb') as f:
            pickle.dump(struct, f)
    except Exception:
        log.warning('Failed to save cache', exc_info=True)
        _cleanup_cache_file(cached_diff_file)
        # do not fall through to the "saved" message below
        return

    log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)
1196
1196
1197
1197
def load_cached_diff(cached_diff_file):
    """
    Load a diff cache previously written to `cached_diff_file`.

    The container format (plain file, gzip or bz2) is chosen by
    ``_get_compression_mode`` from the file path.

    A default structure (current version, ``'diff'`` and ``'commits'``
    both None) is returned when:

    * the file does not exist,
    * reading or unpickling fails (a warning is logged),
    * the payload is empty or not a dict (legacy format),
    * the stored ``'version'`` differs from ``CURRENT_DIFF_VERSION``; only
      in this case the cache file is also removed.

    :param cached_diff_file: path of the cache file to read.
    :returns: dict with the keys ``'version'``, ``'diff'`` and
        ``'commits'``.
    """
    compression_mode = _get_compression_mode(cached_diff_file)

    default_struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': None,
        'commits': None
    }

    if not os.path.isfile(cached_diff_file):
        log.debug('Reading diff cache file failed %s', cached_diff_file)
        return default_struct

    # all three openers accept (path, mode) and work as context managers
    if compression_mode == 'plain':
        opener = open
    elif compression_mode == 'gzip':
        opener = gzip.GzipFile
    else:
        opener = bz2.BZ2File

    data = None

    start = time.time()
    try:
        with opener(cached_diff_file, 'rb') as f:
            # NOTE(security): unpickling can execute arbitrary code; this
            # is only safe while the cache storage is writable by trusted
            # processes only.
            data = pickle.load(f)
    except Exception:
        log.warning('Failed to read diff cache file', exc_info=True)

    if not data or not isinstance(data, dict):
        # unreadable, empty, or old (non-dict) payload: hand back defaults
        # without claiming below that a cache was loaded
        return default_struct

    # check version
    if data.get('version') != CURRENT_DIFF_VERSION:
        # purge cache
        _cleanup_cache_file(cached_diff_file)
        return default_struct

    log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)

    return data
1244
1244
1245
1245
def generate_diff_cache_key(*args):
    """
    Helper to generate a cache key using arguments

    Each argument is converted with ``safe_str``, every ``'/'`` in it is
    replaced by ``'_'``, and the results are joined with ``'_'``. An
    argument that converts to an empty string is rendered as the text
    ``None``.

    :param args: values the key is built from, in order.
    :returns: the cache key; an empty string when called without
        arguments.
    """
    def arg_mapper(input_param):
        input_param = safe_str(input_param)
        # we cannot allow '/' in arguments since it would allow
        # subdirectory usage; str.replace returns a new string, so the
        # result has to be re-bound
        input_param = input_param.replace('/', '_')
        return input_param or None  # prevent empty string arguments

    return '_'.join('{}'.format(arg_mapper(arg)) for arg in args)
1259
1259
1260
1260
def diff_cache_exist(cache_storage, *args):
    """
    Based on all generated arguments check and return a cache path

    The key is built by ``generate_diff_cache_key`` from `args` plus the
    fixed marker ``'mode:gzip'`` and joined onto `cache_storage`. Despite
    the name, this function does not test whether the file exists; it
    only builds and validates the path.

    :param cache_storage: directory that holds the cache files.
    :param args: values the cache key is built from.
    :returns: ``os.path.join(cache_storage, cache_key)``.
    :raises ValueError: when the resulting path is not located inside
        `cache_storage`.
    """
    args = list(args) + ['mode:gzip']
    cache_key = generate_diff_cache_key(*args)
    cache_file_path = os.path.join(cache_storage, cache_key)

    # prevent path traversal attacks using some param that have e.g '../../'
    # Both sides are normalised, and the prefix must end at a path
    # separator, so '/cache_evil/x' is not accepted for storage '/cache'.
    storage_root = os.path.abspath(cache_storage)
    required_prefix = storage_root.rstrip(os.sep) + os.sep
    if not os.path.abspath(cache_file_path).startswith(required_prefix):
        raise ValueError('Final path must be within {}'.format(cache_storage))

    return cache_file_path
General Comments 0
You need to be logged in to leave comments. Login now