##// END OF EJS Templates
python3: 2to3 fixes
super-admin -
r4930:f88262ff default
parent child Browse files
Show More
@@ -1,1764 +1,1761 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 import collections
4 4 import copy
5 5 import datetime
6 6 import hashlib
7 7 import hmac
8 8 import json
9 9 import logging
10 try:
11 import cPickle as pickle
12 except ImportError:
13 import pickle
10 import pickle
14 11 import sys
15 12 import threading
16 13 import time
17 14 from xml.etree import ElementTree
18 15
19 16 from authomatic.exceptions import (
20 17 ConfigError,
21 18 CredentialsError,
22 19 ImportStringError,
23 20 RequestElementsError,
24 21 SessionError,
25 22 )
26 23 from authomatic import six
27 24 from authomatic.six.moves import urllib_parse as parse
28 25
29 26
30 27 # =========================================================================
31 28 # Global variables !!!
32 29 # =========================================================================
33 30
34 31 _logger = logging.getLogger(__name__)
35 32 _logger.addHandler(logging.StreamHandler(sys.stdout))
36 33
37 34 _counter = None
38 35
39 36
def normalize_dict(dict_):
    """
    Replaces all values that are single-item sequences with the value of
    their index 0.

    Text and byte strings are treated as scalar values and are never
    unwrapped. Values which have no length (``None``, numbers, ...) are
    returned unchanged.

    :param dict dict_:
        Dictionary to normalize.

    :returns:
        Normalized dictionary.

    """

    def unwrap(value):
        # Strings are scalars here. ``bytes`` must be excluded explicitly:
        # on Python 3 indexing a one-byte ``bytes`` object yields an ``int``.
        if isinstance(value, (str, bytes)):
            return value

        try:
            is_single = len(value) == 1
        except TypeError:
            # Unsized values can not be single-item sequences.
            return value

        return value[0] if is_single else value

    return {key: unwrap(value) for key, value in dict_.items()}
55 52
56 53
def items_to_dict(items):
    """
    Builds a dictionary out of ``(key, value)`` pairs.

    Values of keys which occur more than once are collected in a list in
    the order of their appearance, the result is then passed through
    :func:`normalize_dict`.

    :param list items:
        List of ``(key, value)`` tuples.

    :returns:
        :class:`dict`

    """

    grouped = {}

    for key, value in items:
        grouped.setdefault(key, []).append(value)

    return normalize_dict(grouped)
76 73
77 74
class Counter(object):
    """
    Minimal incrementing counter used in the config to generate unique
    `id` values.
    """

    def __init__(self, start=0):
        # The last value handed out; the first count() returns start + 1.
        self._count = start

    def count(self):
        """
        Advances the counter by one and returns the new value.
        """
        incremented = self._count + 1
        self._count = incremented
        return incremented
89 86
90 87
91 88 _counter = Counter()
92 89
93 90
def provider_id():
    """
    Returns the next value of the module-level counter. Meant to be used
    in the :doc:`config` to generate unique provider `IDs`.

    :returns:
        :class:`int`.

    Example of usage in the :doc:`config`:
    ::

        import authomatic

        CONFIG = {
            'facebook': {
                'class_': authomatic.providers.oauth2.Facebook,
                'id': authomatic.provider_id(), # returns 1
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['user_about_me', 'email']
            },
            'google': {
                'class_': 'authomatic.providers.oauth2.Google',
                'id': authomatic.provider_id(), # returns 2
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['https://www.googleapis.com/auth/userinfo.profile',
                          'https://www.googleapis.com/auth/userinfo.email']
            },
            'windows_live': {
                'class_': 'oauth2.WindowsLive',
                'id': authomatic.provider_id(), # returns 3
                'consumer_key': '##########',
                'consumer_secret': '##########',
                'scope': ['wl.basic', 'wl.emails', 'wl.photos']
            },
        }

    """

    next_id = _counter.count()
    return next_id
134 131
135 132
def escape(s):
    """
    Percent-encodes a string for use in a URL.

    The string is encoded as UTF-8 first. The slash is escaped too, only
    ``~`` is added to the characters which are never quoted.
    """
    utf8_bytes = s.encode('utf-8')
    return parse.quote(utf8_bytes, safe='~')
141 138
142 139
def json_qs_parser(body):
    """
    Parses a response body which may be JSON, XML or a query string.

    The parsers are tried in this order, the first one which does not
    fail wins. The query string parser is the final fallback.

    :param body:
        string

    :returns:
        :class:`dict`, :class:`list` if input is JSON or query string,
        :class:`xml.etree.ElementTree.Element` if XML.

    """
    # Each parser is paired with the exceptions which mean "wrong format".
    attempts = (
        (json.loads, (OverflowError, TypeError, ValueError)),
        (ElementTree.fromstring,
         (ElementTree.ParseError, TypeError, ValueError)),
    )

    for parser, wrong_format_errors in attempts:
        try:
            return parser(body)
        except wrong_format_errors:
            continue

    # Neither JSON nor XML, so treat it as a query string.
    pairs = parse.parse_qsl(body)
    return dict(pairs)
169 166
170 167
def import_string(import_name, silent=False):
    """
    Imports an object specified by a string in dotted notation.

    taken `from webapp2.import_string() <http://webapp-
    improved.appspot.com/api/webapp2.html#webapp2.import_string>`_

    :param str import_name:
        Dotted path, e.g. ``package.module.Object`` or just ``module``.
    :param bool silent:
        If ``True``, a failed import results in ``None`` instead of
        an exception.

    :returns:
        The imported object, or ``None`` if the import failed and
        :data:`silent` is ``True``.

    """

    try:
        if '.' not in import_name:
            return __import__(import_name)

        # Everything before the last dot is the module, the rest the object.
        module_path, _, attribute = import_name.rpartition('.')
        module = __import__(module_path, None, None, [attribute])
        return getattr(module, attribute)
    except (ImportError, AttributeError) as e:
        if silent:
            return None
        raise ImportStringError(
            'Import from string failed for path {0}'.format(import_name),
            str(e))
190 187
191 188
def resolve_provider_class(class_):
    """
    Returns a provider class.

    :param class_:
        :class:`string` with an import path or an object which is
        returned as it is (expected to be a
        :class:`authomatic.providers.BaseProvider` subclass).

    """

    # Anything which is not a string is considered to be the class already.
    if not isinstance(class_, str):
        return class_

    # Path relative to the providers subpackage of this package.
    providers_path = '.'.join((__package__, 'providers', class_))

    # Try the string as given first, without raising on failure,
    # then fall back to the providers subpackage.
    resolved = import_string(class_, True)
    if not resolved:
        resolved = import_string(providers_path)
    return resolved
210 207
211 208
def id_to_name(config, short_name):
    """
    Finds the provider :doc:`config` key whose ``id`` value matches.

    :param dict config:
        :doc:`config`.
    :param short_name:
        Value of the ``id`` parameter in the :ref:`config` to search for.

    :returns:
        The key of the first provider config with a matching ``id``.

    :raises Exception:
        If no provider config has the given ``id``.

    """

    for provider_name, provider_config in config.items():
        if provider_config.get('id') == short_name:
            return provider_name

    message = 'No provider with id={0} found in the config!'.format(short_name)
    raise Exception(message)
229 226
230 227
class ReprMixin(object):
    """
    Mixin adding a __repr__() in the form *ClassName(arg1=value, arg2=value)*.

    Instance attributes are left out if

    * their value is considered false,
    * their name starts with an underscore,
    * their name is listed in _repr_ignore.

    Values of attributes listed in _repr_sensitive are shown as *###*.
    A value whose repr() is longer than _repr_length_limit is shown as
    *ClassName(...)*, where *ClassName* is the class of the value.

    """

    #: Iterable of attributes to be ignored.
    _repr_ignore = []
    #: Iterable of attributes which value should not be visible.
    _repr_sensitive = []
    #: `int` Values longer than this will be truncated to *ClassName(...)*.
    _repr_length_limit = 20

    def __repr__(self):
        parts = []

        for attr, value in self.__dict__.items():
            # Skip false values, private attributes and ignored names.
            if (not value or attr.startswith('_')
                    or attr in self._repr_ignore):
                continue

            # Mask sensitive values.
            if attr in self._repr_sensitive:
                value = '###'

            shown = repr(value)
            if len(shown) > self._repr_length_limit:
                # Too long, show only the type of the value.
                shown = '{0}(...)'.format(value.__class__.__name__)

            parts.append('{0}={1}'.format(attr, shown))

        return '{0}({1})'.format(self.__class__.__name__, ', '.join(parts))
282 279
283 280
class Future(threading.Thread):
    """
    An activity running in its own thread. A :class:`threading.Thread`
    subclass with an additional :attr:`.get_result` method.

    .. warning::

        |async|

    """

    def __init__(self, func, *args, **kwargs):
        """
        :param callable func:
            The function to be run in separate thread.

        Starts :data:`func` in a separate thread and returns immediately.
        All other positional and keyword arguments are passed
        to :data:`func`.
        """

        # The base initializer must run first, the attributes assigned
        # below would be overwritten by it otherwise.
        super(Future, self).__init__()

        self._func, self._args, self._kwargs = func, args, kwargs
        self._result = None

        # The thread is started as part of the construction.
        self.start()

    def run(self):
        # Executed in the new thread, stores what the function returned.
        result = self._func(*self._args, **self._kwargs)
        self._result = result

    def get_result(self, timeout=None):
        """
        Joins the thread and returns the result of the wrapped
        :data:`func`.

        .. note::

            This will block the **calling thread** until the :data:`func`
            returns or the :data:`timeout` elapses.

        :param timeout:
            :class:`float` or ``None`` A timeout for the :data:`func` to
            return in seconds.

        :returns:
            The result of the wrapped :data:`func`, ``None`` if no result
            has been stored yet.

        """

        self.join(timeout)
        return self._result
336 333
337 334
class Session(object):
    """
    A dictionary-like secure cookie session implementation.

    The session data is pickled, percent encoded and stored in a cookie
    together with a timestamp and an HMAC-SHA1 signature computed with the
    session secret.
    """

    def __init__(self, adapter, secret, name='authomatic', max_age=600,
                 secure=False):
        """
        :param adapter:
            Object providing the ``url`` and ``cookies`` attributes and the
            ``set_header()`` method used by this class.
        :param str secret:
            Session secret used to sign the session cookie.
        :param str name:
            Session cookie name.
        :param int max_age:
            Maximum allowed age of session cookie nonce in seconds.
        :param bool secure:
            If ``True`` the session cookie will be saved with ``Secure``
            attribute.
        """

        self.adapter = adapter
        self.name = name
        self.secret = secret
        self.max_age = max_age
        self.secure = secure
        self._data = {}

    def create_cookie(self, delete=None):
        """
        Creates the value for ``Set-Cookie`` HTTP header.

        :param bool delete:
            If ``True`` the cookie value will be ``deleted`` and the
            Expires value will be ``Thu, 01-Jan-1970 00:00:01 GMT``.

        :returns:
            :class:`str` value of the ``Set-Cookie`` header.

        """
        value = 'deleted' if delete else self._serialize(self.data)
        split_url = parse.urlsplit(self.adapter.url)
        # Host name without the port.
        domain = split_url.netloc.split(':')[0]

        # Work-around for issue #11, failure of WebKit-based browsers to accept
        # cookies set as part of a redirect response in some circumstances.
        if '.' not in domain:
            template = '{name}={value}; Path={path}; HttpOnly{secure}{expires}'
        else:
            template = ('{name}={value}; Domain={domain}; Path={path}; '
                        'HttpOnly{secure}{expires}')

        return template.format(
            name=self.name,
            value=value,
            domain=domain,
            path=split_url.path,
            secure='; Secure' if self.secure else '',
            expires='; Expires=Thu, 01-Jan-1970 00:00:01 GMT' if delete else ''
        )

    def save(self):
        """
        Adds the session cookie to headers.

        Does nothing if there is no session data.

        :raises SessionError:
            If the cookie is longer than 4093 characters.
        """
        if self.data:
            cookie = self.create_cookie()
            cookie_len = len(cookie)

            if cookie_len > 4093:
                raise SessionError('Cookie too long! The cookie size {0} '
                                   'is more than 4093 bytes.'
                                   .format(cookie_len))

            self.adapter.set_header('Set-Cookie', cookie)

            # Reset data
            self._data = {}

    def delete(self):
        """
        Adds a header which deletes the session cookie.
        """
        self.adapter.set_header('Set-Cookie', self.create_cookie(delete=True))

    def _get_data(self):
        """
        Extracts the session data from cookie.
        """
        cookie = self.adapter.cookies.get(self.name)
        return self._deserialize(cookie) if cookie else {}

    @property
    def data(self):
        """
        Gets session data lazily.
        """
        if not self._data:
            self._data = self._get_data()
        # Always return a dict, even if deserialization returned nothing
        if self._data is None:
            self._data = {}
        return self._data

    def _signature(self, *parts):
        """
        Creates signature for the session.

        :param parts:
            Strings which will be joined by ``|`` and signed.

        :returns:
            :class:`str` hexadecimal HMAC-SHA1 digest.
        """
        signature = hmac.new(six.b(self.secret), digestmod=hashlib.sha1)
        signature.update(six.b('|'.join(parts)))
        return signature.hexdigest()

    def _serialize(self, value):
        """
        Converts the value to a signed string with timestamp.

        :param value:
            Object to be serialized.

        :returns:
            Serialized value.

        """

        # 1. Serialize
        # latin-1 maps every byte to exactly one character, so the pickled
        # bytes survive the round trip through text.
        serialized = pickle.dumps(value).decode('latin-1')

        # 2. Encode
        # Percent encoding produces smaller result then urlsafe base64.
        encoded = parse.quote(serialized, '')

        # 3. Concatenate
        timestamp = str(int(time.time()))
        signature = self._signature(self.name, encoded, timestamp)
        concatenated = '|'.join([encoded, timestamp, signature])

        return concatenated

    def _deserialize(self, value):
        """
        Deserializes and verifies the value created by :meth:`._serialize`.

        :param str value:
            The serialized value.

        :returns:
            Deserialized object or ``None`` if the value is older than
            :attr:`.max_age`.

        :raises SessionError:
            If the value does not consist of three parts or if its
            signature is invalid.

        """

        # 3. Split
        try:
            encoded, timestamp, signature = value.split('|')
        except ValueError:
            # The cookie comes from the client, it can contain anything.
            raise SessionError('Malformed session cookie!')

        # Verify signature
        # Compare in constant time to avoid leaking the expected signature
        # through timing. Bytes are compared because compare_digest()
        # rejects non-ASCII text.
        expected = self._signature(self.name, encoded, timestamp)
        if not hmac.compare_digest(signature.encode('utf-8'),
                                   expected.encode('utf-8')):
            raise SessionError('Invalid signature "{0}"!'.format(signature))

        # Verify timestamp
        if int(timestamp) < int(time.time()) - self.max_age:
            return None

        # 2. Decode
        decoded = parse.unquote(encoded)

        # 1. Deserialize
        # SECURITY: unpickling can execute arbitrary code. It is only reached
        # after the signature was verified, so the safety of this call
        # depends entirely on the secrecy of the session secret.
        deserialized = pickle.loads(decoded.encode('latin-1'))

        return deserialized

    def __setitem__(self, key, value):
        self._data[key] = value

    def __getitem__(self, key):
        return self.data.__getitem__(key)

    def __delitem__(self, key):
        return self._data.__delitem__(key)

    def get(self, key, default=None):
        return self.data.get(key, default)
513 510
514 511
class User(ReprMixin):
    """
    A unified interface to selected **user** info as returned by the
    various **providers**.

    .. note:: The value format may vary across providers.

    """

    def __init__(self, provider, **kwargs):
        # Every optional value defaults to None if not passed.
        get = kwargs.get

        #: A :doc:`provider <providers>` instance.
        self.provider = provider

        #: An :class:`.Credentials` instance.
        self.credentials = get('credentials')

        #: A :class:`dict` containing all the **user** information returned
        #: by the **provider**.
        #: The structure differs across **providers**.
        self.data = get('data')

        #: The :attr:`.Response.content` of the request made to update
        #: the user.
        self.content = get('content')

        #: :class:`str` ID assigned to the **user** by the **provider**.
        self.id = get('id')
        #: :class:`str` User name e.g. *andrewpipkin*.
        self.username = get('username')
        #: :class:`str` Name e.g. *Andrew Pipkin*.
        self.name = get('name')
        #: :class:`str` First name e.g. *Andrew*.
        self.first_name = get('first_name')
        #: :class:`str` Last name e.g. *Pipkin*.
        self.last_name = get('last_name')
        #: :class:`str` Nickname e.g. *Andy*.
        self.nickname = get('nickname')
        #: :class:`str` Link URL.
        self.link = get('link')
        #: :class:`str` Gender.
        self.gender = get('gender')
        #: :class:`str` Timezone.
        self.timezone = get('timezone')
        #: :class:`str` Locale.
        self.locale = get('locale')
        #: :class:`str` E-mail.
        self.email = get('email')
        #: :class:`str` phone.
        self.phone = get('phone')
        #: :class:`str` Picture URL.
        self.picture = get('picture')
        #: Birth date as :class:`datetime.datetime()` or :class:`str`
        #: if parsing failed or ``None``.
        self.birth_date = get('birth_date')
        #: :class:`str` Country.
        self.country = get('country')
        #: :class:`str` City.
        self.city = get('city')
        #: :class:`str` Geographical location.
        self.location = get('location')
        #: :class:`str` Postal code.
        self.postal_code = get('postal_code')
        #: Instance of the Google App Engine Users API
        #: `User <https://developers.google.com/appengine/docs/python/users/userclass>`_ class.
        #: Only present when using the :class:`authomatic.providers.gaeopenid.GAEOpenID` provider.
        self.gae_user = get('gae_user')

    def update(self):
        """
        Updates the user info by calling the ``update_user()`` method of
        the **provider**.

        :returns:
            Whatever the provider's ``update_user()`` returns.

        """

        updated = self.provider.update_user()
        return updated

    def async_update(self):
        """
        Same as :meth:`.update` but runs asynchronously in a separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.update)

    def to_dict(self):
        """
        Converts the :class:`.User` instance to a :class:`dict`.

        The provider is replaced by its name, the credentials by their
        serialized form, the birth date by its string representation and
        the content is left out.

        :returns:
            :class:`dict`

        """

        # Shallow copy, the instance attributes stay untouched.
        result = dict(self.__dict__)

        # Keep only the provider name to avoid circular reference
        result['provider'] = self.provider.name

        if self.credentials:
            result['credentials'] = self.credentials.serialize()
        else:
            result['credentials'] = None

        result['birth_date'] = str(result['birth_date'])

        # Remove content
        del result['content']

        # The data is dropped if it is an XML element.
        if isinstance(self.data, ElementTree.Element):
            result['data'] = None

        return result
633 630
634 631
# Plain named tuple with one field per user attribute.
SupportedUserAttributesNT = collections.namedtuple(
    'SupportedUserAttributesNT',
    [
        'birth_date', 'city', 'country', 'email', 'first_name',
        'gender', 'id', 'last_name', 'link', 'locale', 'location',
        'name', 'nickname', 'phone', 'picture', 'postal_code',
        'timezone', 'username',
    ],
)


class SupportedUserAttributes(SupportedUserAttributesNT):
    """
    Named tuple of user attribute flags which accepts keyword arguments
    only; every field which is not passed defaults to ``False``.
    """

    def __new__(cls, **kwargs):
        values = dict.fromkeys(SupportedUserAttributes._fields, False)  # pylint:disable=no-member
        values.update(kwargs)
        return super(SupportedUserAttributes, cls).__new__(cls, **values)
649 646
650 647
class Credentials(ReprMixin):
    """
    Contains all necessary information to fetch **user's protected resources**.
    """

    _repr_sensitive = ('token', 'refresh_token', 'token_secret',
                       'consumer_key', 'consumer_secret')

    def __init__(self, config, **kwargs):

        #: :class:`dict` :doc:`config`.
        self.config = config

        #: :class:`str` User **access token**.
        self.token = kwargs.get('token', '')

        #: :class:`str` Access token type.
        self.token_type = kwargs.get('token_type', '')

        #: :class:`str` Refresh token.
        self.refresh_token = kwargs.get('refresh_token', '')

        #: :class:`str` Access token secret.
        self.token_secret = kwargs.get('token_secret', '')

        #: :class:`int` Expiration date as UNIX timestamp.
        self.expiration_time = int(kwargs.get('expiration_time', 0))

        #: A :doc:`Provider <providers>` instance.
        provider = kwargs.get('provider')

        # Must be assigned after expiration_time: a non-zero value
        # recomputes the expiration time (see the expire_in setter).
        self.expire_in = int(kwargs.get('expire_in', 0))

        if provider:
            #: :class:`str` Provider name specified in the :doc:`config`.
            self.provider_name = provider.name

            #: :class:`str` Provider type e.g.
            # ``"authomatic.providers.oauth2.OAuth2"``.
            self.provider_type = provider.get_type()

            #: Provider type ID taken from the ``type_id`` attribute
            #: of the provider.
            self.provider_type_id = provider.type_id

            #: :class:`int` Provider ID specified in the :doc:`config`
            #: or ``None``.
            self.provider_id = int(provider.id) if provider.id else None

            #: :class:`class` Provider class.
            self.provider_class = provider.__class__

            #: :class:`str` Consumer key specified in the :doc:`config`.
            self.consumer_key = provider.consumer_key

            #: :class:`str` Consumer secret specified in the :doc:`config`.
            self.consumer_secret = provider.consumer_secret

        else:
            self.provider_name = kwargs.get('provider_name', '')
            self.provider_type = kwargs.get('provider_type', '')
            self.provider_type_id = kwargs.get('provider_type_id')
            self.provider_id = kwargs.get('provider_id')
            self.provider_class = kwargs.get('provider_class')

            self.consumer_key = kwargs.get('consumer_key', '')
            self.consumer_secret = kwargs.get('consumer_secret', '')

    @property
    def expire_in(self):
        """
        The stored number of seconds until the credentials expire.
        """

        return self._expire_in

    @expire_in.setter
    def expire_in(self, value):
        """
        Computes :attr:`.expiration_time` when the value is set.

        A false value (``0``, ``None``) leaves the expiration time as it is.
        """

        # pylint:disable=attribute-defined-outside-init
        if value:
            self._expiration_time = int(time.time()) + int(value)
        self._expire_in = value

    @property
    def expiration_time(self):
        """
        Expiration date as UNIX timestamp.
        """
        return self._expiration_time

    @expiration_time.setter
    def expiration_time(self, value):
        """
        Stores the timestamp and recomputes :attr:`.expire_in` relative to
        the current time.
        """

        # pylint:disable=attribute-defined-outside-init
        self._expiration_time = int(value)
        self._expire_in = self._expiration_time - int(time.time())

    @property
    def expiration_date(self):
        """
        Expiration date as :class:`datetime.datetime` or ``None`` if
        :attr:`.expire_in` is negative.
        """

        if self.expire_in < 0:
            return None
        else:
            return datetime.datetime.fromtimestamp(self.expiration_time)

    @property
    def valid(self):
        """
        ``True`` if credentials are valid, ``False`` if expired.

        Credentials without an expiration time are always valid.
        """

        if self.expiration_time:
            return self.expiration_time > int(time.time())
        else:
            return True

    def expire_soon(self, seconds):
        """
        Returns ``True`` if credentials expire sooner than specified.

        :param int seconds:
            Number of seconds.

        :returns:
            ``True`` if credentials expire sooner than specified,
            else ``False``.

        """

        if self.expiration_time:
            return self.expiration_time < int(time.time()) + int(seconds)
        else:
            return False

    def refresh(self, force=False, soon=86400):
        """
        Refreshes the credentials only if the **provider** supports it and if
        it will expire in less than one day. It does nothing in other cases.

        .. note::

            The credentials will be refreshed only if it gives sense
            i.e. only |oauth2|_ has the notion of credentials
            *refreshment/extension*.
            And there are also differences across providers e.g. Google
            supports refreshment only if there is a ``refresh_token`` in
            the credentials and that in turn is present only if the
            ``access_type`` parameter was set to ``offline`` in the
            **user authorization request**.

        :param bool force:
            If ``True`` the credentials will be refreshed even if they
            won't expire soon.

        :param int soon:
            Number of seconds specifying what means *soon*.

        :returns:
            The result of the provider's ``refresh_credentials()`` or
            ``None`` if nothing was refreshed.

        """

        if hasattr(self.provider_class, 'refresh_credentials'):
            if force or self.expire_soon(soon):
                logging.info('PROVIDER NAME: {0}'.format(self.provider_name))
                return self.provider_class(
                    self, None, self.provider_name).refresh_credentials(self)

    def async_refresh(self, *args, **kwargs):
        """
        Same as :meth:`.refresh` but runs asynchronously in a separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.refresh, *args, **kwargs)

    def provider_type_class(self):
        """
        Returns the :doc:`provider <providers>` class specified in the
        :doc:`config`.

        :returns:
            :class:`authomatic.providers.BaseProvider` subclass.

        """

        return resolve_provider_class(self.provider_type)

    def serialize(self):
        """
        Converts the credentials to a percent encoded string to be stored for
        later use.

        :returns:
            :class:`string`

        :raises ConfigError:
            If the credentials have no provider ID.

        """

        if self.provider_id is None:
            raise ConfigError(
                'To serialize credentials you need to specify a '
                'unique integer under the "id" key in the config '
                'for each provider!')

        # Get the provider type specific items.
        rest = self.provider_type_class().to_tuple(self)

        # Provider ID and provider type ID are always the first two items.
        result = (self.provider_id, self.provider_type_id) + rest

        # Make sure that all items are strings.
        stringified = [str(i) for i in result]

        # Concatenate by newline.
        concatenated = '\n'.join(stringified)

        # Percent encode.
        return parse.quote(concatenated, '')

    @classmethod
    def deserialize(cls, config, credentials):
        """
        A *class method* which reconstructs credentials created by
        :meth:`serialize`. You can also pass it a :class:`.Credentials`
        instance.

        :param dict config:
            The same :doc:`config` used in the :func:`.login` to get the
            credentials.
        :param str credentials:
            :class:`string` The serialized credentials or
            :class:`.Credentials` instance.

        :returns:
            :class:`.Credentials`

        :raises CredentialsError:
            If the serialized value does not start with an integer
            provider ID.

        """

        # Accept both serialized and normal.
        if isinstance(credentials, Credentials):
            return credentials

        decoded = parse.unquote(credentials)

        split = decoded.split('\n')

        # We need the provider ID to move forward. str.split() never yields
        # None, so the only reliable check is the integer conversion.
        try:
            parsed_provider_id = int(split[0])
        except ValueError:
            raise CredentialsError(
                'To deserialize credentials you need to specify a unique '
                'integer under the "id" key in the config for each provider!')

        # Get provider config by short name.
        provider_name = id_to_name(config, parsed_provider_id)
        cfg = config.get(provider_name)

        # Get the provider class.
        ProviderClass = resolve_provider_class(cfg.get('class_'))

        deserialized = Credentials(config)

        # Store the parsed ID, not the module-level provider_id() function.
        deserialized.provider_id = parsed_provider_id
        deserialized.provider_type = ProviderClass.get_type()
        deserialized.provider_type_id = split[1]
        deserialized.provider_class = ProviderClass
        deserialized.provider_name = provider_name

        # Add provider type specific properties.
        return ProviderClass.reconstruct(split[2:], deserialized, cfg)
929 926
930 927
931 928 class LoginResult(ReprMixin):
932 929 """
933 930 Result of the :func:`authomatic.login` function.
934 931 """
935 932
936 933 def __init__(self, provider):
937 934 #: A :doc:`provider <providers>` instance.
938 935 self.provider = provider
939 936
940 937 #: An instance of the :exc:`authomatic.exceptions.BaseError` subclass.
941 938 self.error = None
942 939
943 940 def popup_js(self, callback_name=None, indent=None,
944 941 custom=None, stay_open=False):
945 942 """
946 943 Returns JavaScript that:
947 944
948 945 #. Triggers the ``options.onLoginComplete(result, closer)``
949 946 handler set with the :ref:`authomatic.setup() <js_setup>`
950 947 function of :ref:`javascript.js <js>`.
951 948 #. Calls the JavasScript callback specified by :data:`callback_name`
952 949 on the opener of the *login handler popup* and passes it the
953 950 *login result* JSON object as first argument and the `closer`
954 951 function which you should call in your callback to close the popup.
955 952
956 953 :param str callback_name:
957 954 The name of the javascript callback e.g ``foo.bar.loginCallback``
958 955 will result in ``window.opener.foo.bar.loginCallback(result);``
959 956 in the HTML.
960 957
961 958 :param int indent:
962 959 The number of spaces to indent the JSON result object.
963 960 If ``0`` or negative, only newlines are added.
964 961 If ``None``, no newlines are added.
965 962
966 963 :param custom:
967 964 Any JSON serializable object that will be passed to the
968 965 ``result.custom`` attribute.
969 966
970 967 :param str stay_open:
971 968 If ``True``, the popup will stay open.
972 969
973 970 :returns:
974 971 :class:`str` with JavaScript.
975 972
976 973 """
977 974
978 975 custom_callback = """
979 976 try {{ window.opener.{cb}(result, closer); }} catch(e) {{}}
980 977 """.format(cb=callback_name) if callback_name else ''
981 978
982 979 # TODO: Move the window.close() to the opener
983 980 return """
984 981 (function(){{
985 982
986 983 closer = function(){{
987 984 window.close();
988 985 }};
989 986
990 987 var result = {result};
991 988 result.custom = {custom};
992 989
993 990 {custom_callback}
994 991
995 992 try {{
996 993 window.opener.authomatic.loginComplete(result, closer);
997 994 }} catch(e) {{}}
998 995
999 996 }})();
1000 997
1001 998 """.format(result=self.to_json(indent),
1002 999 custom=json.dumps(custom),
1003 1000 custom_callback=custom_callback,
1004 1001 stay_open='// ' if stay_open else '')
1005 1002
def popup_html(self, callback_name=None, indent=None,
               title='Login | {0}', custom=None, stay_open=False):
    """
    Returns a HTML document with embedded JavaScript that:

    #. Triggers the ``options.onLoginComplete(result, closer)`` handler
       set with the :ref:`authomatic.setup() <js_setup>` function of
       :ref:`javascript.js <js>`.
    #. Calls the JavaScript callback specified by :data:`callback_name`
       on the opener of the *login handler popup* and passes it the
       *login result* JSON object as first argument and the `closer`
       function which you should call in your callback to close the popup.

    The script itself is produced by :meth:`popup_js`; this method only
    wraps it in a minimal HTML page.

    :param str callback_name:
        The name of the javascript callback e.g ``foo.bar.loginCallback``
        will result in ``window.opener.foo.bar.loginCallback(result);``
        in the HTML.

    :param int indent:
        The number of spaces to indent the JSON result object.
        If ``0`` or negative, only newlines are added.
        If ``None``, no newlines are added.

    :param str title:
        The text of the HTML title. You can use ``{0}`` tag inside,
        which will be replaced by the provider name (an empty string when
        there is no provider).

    :param custom:
        Any JSON serializable object that will be passed to the
        ``result.custom`` attribute.

    :param bool stay_open:
        If ``True``, the popup will stay open.

    :returns:
        :class:`str` with HTML.

    """

    provider_name = self.provider.name if self.provider else ''
    page_title = title.format(provider_name)
    script = self.popup_js(callback_name, indent, custom, stay_open)

    return """
        <!DOCTYPE html>
        <html>
            <head><title>{title}</title></head>
            <body>
            <script type="text/javascript">
                {js}
            </script>
            </body>
        </html>
        """.format(title=page_title, js=script)
1059 1056
@property
def user(self):
    """
    The :class:`.User` instance exposed by the provider, or ``None`` when
    the result has no provider.
    """

    if not self.provider:
        return None
    return self.provider.user
1067 1064
def to_dict(self):
    """
    Returns a :class:`dict` with the ``provider``, ``user`` and ``error``
    attributes of this result.
    """
    return {
        'provider': self.provider,
        'user': self.user,
        'error': self.error,
    }
1070 1067
def to_json(self, indent=4):
    """
    Serializes the result to a JSON string.

    Objects that :mod:`json` cannot encode natively are replaced by the
    return value of their ``to_dict()`` method, or by an empty string if
    they have none.

    :param indent:
        Passed through to :func:`json.dumps`.

    :returns:
        :class:`str` with JSON.
    """

    def _fallback(obj):
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        return ''

    return json.dumps(self, default=_fallback, indent=indent)
1074 1071
1075 1072
class Response(ReprMixin):
    """
    Wraps :class:`httplib.HTTPResponse` (``http.client.HTTPResponse`` on
    Python 3) and adds the lazily computed :attr:`.content` and
    :attr:`.data` attributes.

    """

    def __init__(self, httplib_response, content_parser=None):
        """
        :param httplib_response:
            The wrapped :class:`httplib.HTTPResponse` instance.

        :param function content_parser:
            Callable which accepts :attr:`.content` as argument,
            parses it and returns the parsed data as :class:`dict`.
            Defaults to ``json_qs_parser``.
        """

        self.httplib_response = httplib_response
        self.content_parser = content_parser or json_qs_parser
        # ``None`` means "not computed yet"; see ``content`` and ``data``.
        self._data = None
        self._content = None

        #: Same as :attr:`httplib.HTTPResponse.msg`.
        self.msg = httplib_response.msg
        #: Same as :attr:`httplib.HTTPResponse.version`.
        self.version = httplib_response.version
        #: Same as :attr:`httplib.HTTPResponse.status`.
        self.status = httplib_response.status
        #: Same as :attr:`httplib.HTTPResponse.reason`.
        self.reason = httplib_response.reason

    def read(self, amt=None):
        """
        Same as :meth:`httplib.HTTPResponse.read`.

        :param amt:
            Forwarded unchanged to the wrapped response.

        """

        return self.httplib_response.read(amt)

    def getheader(self, name, default=None):
        """
        Same as :meth:`httplib.HTTPResponse.getheader`.

        :param name:
            Forwarded unchanged to the wrapped response.
        :param default:
            Forwarded unchanged to the wrapped response.

        """

        return self.httplib_response.getheader(name, default)

    def fileno(self):
        """
        Same as :meth:`httplib.HTTPResponse.fileno`.
        """
        return self.httplib_response.fileno()

    def getheaders(self):
        """
        Same as :meth:`httplib.HTTPResponse.getheaders`.
        """
        return self.httplib_response.getheaders()

    @staticmethod
    def is_binary_string(content):
        """
        Return true if string is binary data.

        :param bytes content:
            The raw payload. Every byte listed in ``textchars`` is deleted;
            whatever is left over marks the payload as binary.
        """

        textchars = (bytearray([7, 8, 9, 10, 12, 13, 27]) +
                     bytearray(range(0x20, 0x100)))
        return bool(content.translate(None, textchars))

    @property
    def content(self):
        """
        The whole response content.

        Read from the wrapped response on first access and cached. Binary
        payloads are kept as read, everything else is decoded as UTF-8.
        """

        # Compare with ``None`` rather than truthiness so that an empty
        # body is cached too instead of triggering another read() on every
        # access.
        if self._content is None:
            content = self.httplib_response.read()
            if self.is_binary_string(content):
                self._content = content
            else:
                self._content = content.decode('utf-8')
        return self._content

    @property
    def data(self):
        """
        A :class:`dict` of data parsed from :attr:`.content`.

        Parsed on first access with ``content_parser`` and cached.
        """

        # ``is None`` so that an empty parse result is not re-parsed on
        # every access.
        if self._data is None:
            self._data = self.content_parser(self.content)
        return self._data
1174 1171
1175 1172
class UserInfoResponse(Response):
    """
    A :class:`.Response` which additionally carries the
    :attr:`~UserInfoResponse.user` attribute.
    """

    def __init__(self, user, *args, **kwargs):
        """
        :param user:
            :class:`.User` instance.

        All remaining arguments are handed to :class:`.Response`.
        """
        super().__init__(*args, **kwargs)

        #: :class:`.User` instance.
        self.user = user
1187 1184
1188 1185
class RequestElements(tuple):
    """
    A five-item tuple ``(url, method, params, headers, body)`` describing
    a request.

    The items are additionally reachable through named read-only
    properties, and a few derived values are provided.

    """

    def __new__(cls, url, method, params, headers, body):
        elements = (url, method, params, headers, body)
        return tuple.__new__(cls, elements)

    @property
    def url(self):
        """
        Request URL.
        """

        return self[0]

    @property
    def method(self):
        """
        HTTP method of the request.
        """

        return self[1]

    @property
    def params(self):
        """
        Dictionary of request parameters.
        """

        return self[2]

    @property
    def headers(self):
        """
        Dictionary of request headers.
        """

        return self[3]

    @property
    def body(self):
        """
        :class:`str` Body of ``POST``, ``PUT`` and ``PATCH`` requests.
        """

        return self[4]

    @property
    def query_string(self):
        """
        Query string of the request (the URL-encoded :attr:`params`).
        """

        return parse.urlencode(self.params)

    @property
    def full_url(self):
        """
        URL with query string, joined by ``?``.
        """

        base = self.url
        return base + '?' + self.query_string

    def to_json(self):
        """
        Returns the five request elements as a JSON object string.
        """
        payload = {
            'url': self.url,
            'method': self.method,
            'params': self.params,
            'headers': self.headers,
            'body': self.body,
        }
        return json.dumps(payload)
1262 1259
1263 1260
class Authomatic(object):
    def __init__(
        self, config, secret, session_max_age=600, secure_cookie=False,
        session=None, session_save_method=None, report_errors=True,
        debug=False, logging_level=logging.INFO, prefix='authomatic',
        logger=None
    ):
        """
        Encapsulates all the functionality of this package.

        :param dict config:
            :doc:`config`

        :param str secret:
            A secret string that will be used as the key for signing
            :class:`.Session` cookie and as a salt by *CSRF* token generation.

        :param session_max_age:
            Maximum allowed age of :class:`.Session` cookie nonce in seconds.

        :param bool secure_cookie:
            If ``True`` the :class:`.Session` cookie will be saved with
            ``Secure`` attribute.

        :param session:
            Custom dictionary-like session implementation.

        :param callable session_save_method:
            A method of the supplied session or any mechanism that saves the
            session data and cookie.

        :param bool report_errors:
            If ``True`` exceptions encountered during the **login procedure**
            will be caught and reported in the :attr:`.LoginResult.error`
            attribute.
            Default is ``True``.

        :param bool debug:
            If ``True`` traceback of exceptions will be written to response.
            Default is ``False``.

        :param int logging_level:
            The logging level threshold for the default logger as specified in
            the standard Python
            `logging library <http://docs.python.org/2/library/logging.html>`_.
            This setting is ignored when :data:`logger` is set.
            Default is ``logging.INFO``.

        :param str prefix:
            Prefix used as the :class:`.Session` cookie name.

        :param logger:
            A :class:`logging.logger` instance.

        """

        self.config = config
        self.secret = secret
        self.session_max_age = session_max_age
        self.secure_cookie = secure_cookie
        self.session = session
        self.session_save_method = session_save_method
        self.report_errors = report_errors
        self.debug = debug
        self.logging_level = logging_level
        self.prefix = prefix
        # Without an explicit logger every instance gets its own one, named
        # after the instance id.
        self._logger = logger or logging.getLogger(str(id(self)))

        # Set logging level (only for the default logger; a supplied logger
        # keeps whatever level its owner configured).
        if logger is None:
            self._logger.setLevel(logging_level)

    def login(self, adapter, provider_name, callback=None,
              session=None, session_saver=None, **kwargs):
        """
        If :data:`provider_name` specified, launches the login procedure for
        corresponding :doc:`provider </reference/providers>` and returns
        :class:`.LoginResult`.

        If :data:`provider_name` is empty, acts like
        :meth:`.Authomatic.backend`.

        .. warning::

            The method redirects the **user** to the **provider** which in
            turn redirects **him/her** back to the *request handler* where
            it has been called.

        :param adapter:
            An :doc:`adapter <adapters>`; handed to the provider, to the
            :class:`.Session` and to :meth:`.Authomatic.backend`.

        :param str provider_name:
            Name of the provider as specified in the keys of the :doc:`config`.

        :param callable callback:
            If specified the method will call the callback with
            :class:`.LoginResult` passed as argument and will return nothing.

        :param session:
            Custom session. Used only together with :data:`session_saver`;
            if either of the two is ``None`` a new :class:`.Session` is
            created and its ``save`` method is used as the saver.

        :param callable session_saver:
            Callable saving the custom :data:`session`.

        .. note::

            Accepts additional keyword arguments that will be passed to
            :doc:`provider <providers>` constructor.

        :returns:
            :class:`.LoginResult`, or ``None`` when :data:`provider_name`
            is empty.

        :raises ConfigError:
            If the provider is missing in the config or its settings do not
            contain the ``class_`` key.

        """

        if not provider_name:
            # Act like backend.
            self.backend(adapter)
            return None

        # retrieve required settings for current provider and raise
        # exceptions if missing
        provider_settings = self.config.get(provider_name)
        if not provider_settings:
            raise ConfigError('Provider name "{0}" not specified!'
                              .format(provider_name))

        # A custom session is honoured only when its saver is supplied too;
        # otherwise fall back to the built-in cookie session.
        if session is None or session_saver is None:
            session = Session(adapter=adapter,
                              secret=self.secret,
                              max_age=self.session_max_age,
                              name=self.prefix,
                              secure=self.secure_cookie)

            session_saver = session.save

        # Resolve provider class.
        class_ = provider_settings.get('class_')
        if not class_:
            raise ConfigError(
                'The "class_" key not specified in the config'
                ' for provider {0}!'.format(provider_name))
        ProviderClass = resolve_provider_class(class_)

        # FIXME: Find a nicer solution
        ProviderClass._logger = self._logger

        # instantiate provider class
        provider = ProviderClass(self,
                                 adapter=adapter,
                                 provider_name=provider_name,
                                 callback=callback,
                                 session=session,
                                 session_saver=session_saver,
                                 **kwargs)

        # return login result
        return provider.login()

    def credentials(self, credentials):
        """
        Deserializes credentials.

        :param credentials:
            Credentials serialized with :meth:`.Credentials.serialize` or
            :class:`.Credentials` instance.

        :returns:
            :class:`.Credentials`

        """

        return Credentials.deserialize(self.config, credentials)

    def access(self, credentials, url, params=None, method='GET',
               headers=None, body='', max_redirects=5, content_parser=None):
        """
        Accesses **protected resource** on behalf of the **user**.

        :param credentials:
            The **user's** :class:`.Credentials` (serialized or normal).

        :param str url:
            The **protected resource** URL.

        :param dict params:
            Parameters of the request.

        :param str method:
            HTTP method of the request.

        :param dict headers:
            HTTP headers of the request.

        :param str body:
            Body of ``POST``, ``PUT`` and ``PATCH`` requests.

        :param int max_redirects:
            Maximum number of HTTP redirects to follow.

        :param function content_parser:
            A function to be used to parse the :attr:`.Response.data`
            from :attr:`.Response.content`.

        :returns:
            :class:`.Response`

        """

        # Deserialize credentials.
        credentials = Credentials.deserialize(self.config, credentials)

        # Resolve provider class.
        ProviderClass = credentials.provider_class
        # NOTE(review): this goes to the root logger, not ``self._logger``,
        # and the headers may carry an ``Authorization`` value.
        logging.info('ACCESS HEADERS: {0}'.format(headers))

        # Access resource and return response.
        provider = ProviderClass(
            self, adapter=None, provider_name=credentials.provider_name)
        provider.credentials = credentials

        return provider.access(url=url,
                               params=params,
                               method=method,
                               headers=headers,
                               body=body,
                               max_redirects=max_redirects,
                               content_parser=content_parser)

    def async_access(self, *args, **kwargs):
        """
        Same as :meth:`.Authomatic.access` but runs asynchronously in a
        separate thread.

        .. warning::

            |async|

        :returns:
            :class:`.Future` instance representing the separate thread.

        """

        return Future(self.access, *args, **kwargs)

    def request_elements(
            self, credentials=None, url=None, method='GET', params=None,
            headers=None, body='', json_input=None, return_json=False
    ):
        """
        Creates request elements for accessing **protected resource of a
        user**. Required arguments are :data:`credentials` and :data:`url`. You
        can pass :data:`credentials`, :data:`url`, :data:`method`, and
        :data:`params` as a JSON object.

        :param credentials:
            The **user's** credentials (can be serialized).

        :param str url:
            The url of the protected resource.

        :param str method:
            The HTTP method of the request.

        :param dict params:
            Dictionary of request parameters.

        :param dict headers:
            Dictionary of request headers.

        :param str body:
            Body of ``POST``, ``PUT`` and ``PATCH`` requests.

        :param str json_input:
            you can pass :data:`credentials`, :data:`url`, :data:`method`,
            :data:`params` and :data:`headers` in a JSON object.
            Values from arguments will be used for missing properties.

            ::

                {
                    "credentials": "###",
                    "url": "https://example.com/api",
                    "method": "POST",
                    "params": {
                        "foo": "bar"
                    },
                    "headers": {
                        "baz": "bing",
                        "Authorization": "Bearer ###"
                    },
                    "body": "Foo bar baz bing."
                }

        :param bool return_json:
            if ``True`` the function returns a json object.

            ::

                {
                    "url": "https://example.com/api",
                    "method": "POST",
                    "params": {
                        "access_token": "###",
                        "foo": "bar"
                    },
                    "headers": {
                        "baz": "bing",
                        "Authorization": "Bearer ###"
                    },
                    "body": "Foo bar baz bing."
                }

        :returns:
            :class:`.RequestElements` or JSON string.

        :raises RequestElementsError:
            If credentials or URL are missing.

        """

        # Parse values from JSON
        if json_input:
            parsed_input = json.loads(json_input)

            credentials = parsed_input.get('credentials', credentials)
            url = parsed_input.get('url', url)
            method = parsed_input.get('method', method)
            params = parsed_input.get('params', params)
            headers = parsed_input.get('headers', headers)
            body = parsed_input.get('body', body)

        # Both values are mandatory. The parentheses matter: without them
        # the check reads ``(not credentials) and url`` and lets a missing
        # URL through.
        if not (credentials and url):
            raise RequestElementsError(
                'To create request elements, you must provide credentials '
                'and URL either as keyword arguments or in the JSON object!')

        # Get the provider class
        credentials = Credentials.deserialize(self.config, credentials)
        ProviderClass = credentials.provider_class

        # Create request elements
        request_elements = ProviderClass.create_request_elements(
            ProviderClass.PROTECTED_RESOURCE_REQUEST_TYPE,
            credentials=credentials,
            url=url,
            method=method,
            params=params,
            headers=headers,
            body=body)

        if return_json:
            return request_elements.to_json()

        return request_elements

    def backend(self, adapter):
        """
        Converts a *request handler* to a JSON backend which you can use with
        :ref:`authomatic.js <js>`.

        Just call it inside a *request handler* like this:

        ::

            class JSONHandler(webapp2.RequestHandler):
                def get(self):
                    authomatic.backend(Webapp2Adapter(self))

        :param adapter:
            The only argument is an :doc:`adapter <adapters>`.

        The *request handler* will now accept these request parameters:

        :param str type:
            Type of the request. Either ``auto``, ``fetch`` or ``elements``.
            Default is ``auto``.

        :param str credentials:
            Serialized :class:`.Credentials`.

        :param str url:
            URL of the **protected resource** request.

        :param str method:
            HTTP method of the **protected resource** request.

        :param str body:
            HTTP body of the **protected resource** request.

        :param JSON params:
            HTTP params of the **protected resource** request as a JSON object.

        :param JSON headers:
            HTTP headers of the **protected resource** request as a
            JSON object.

        :param JSON json:
            You can pass all of the aforementioned params except ``type``
            in a JSON object.

            .. code-block:: javascript

                {
                    "credentials": "######",
                    "url": "https://example.com",
                    "method": "POST",
                    "params": {"foo": "bar"},
                    "headers": {"baz": "bing"},
                    "body": "the body of the request"
                }

        Depending on the ``type`` param, the handler will either write
        a JSON object with *request elements* to the response,
        and add an ``Authomatic-Response-To: elements`` response header, ...

        .. code-block:: javascript

            {
                "url": "https://example.com/api",
                "method": "POST",
                "params": {
                    "access_token": "###",
                    "foo": "bar"
                },
                "headers": {
                    "baz": "bing",
                    "Authorization": "Bearer ###"
                }
            }

        ... or make a fetch to the **protected resource** and forward
        its response content, status and headers with an additional
        ``Authomatic-Response-To: fetch`` header to the response.

        .. warning::

            The backend will not work if you write anything to the
            response in the handler!

        """

        AUTHOMATIC_HEADER = 'Authomatic-Response-To'

        # Collect request params
        request_type = adapter.params.get('type', 'auto')
        json_input = adapter.params.get('json')
        credentials = adapter.params.get('credentials')
        url = adapter.params.get('url')
        method = adapter.params.get('method', 'GET')
        body = adapter.params.get('body', '')

        params = adapter.params.get('params')
        params = json.loads(params) if params else {}

        headers = adapter.params.get('headers')
        headers = json.loads(headers) if headers else {}

        ProviderClass = Credentials.deserialize(
            self.config, credentials).provider_class

        if request_type == 'auto':
            # If there is a "callback" param, it's a JSONP request.
            jsonp = params.get('callback')

            # JSONP is possible only with GET method.
            # Compare by value: ``method`` comes from the request, so an
            # identity test against the literal is not reliable.
            if ProviderClass.supports_jsonp and method == 'GET':
                request_type = 'elements'
            else:
                # Remove the JSONP callback
                if jsonp:
                    params.pop('callback')
                request_type = 'fetch'

        if request_type == 'fetch':
            # Access protected resource
            response = self.access(
                credentials, url, params, method, headers, body)
            result = response.content

            # Forward status
            adapter.status = str(response.status) + ' ' + str(response.reason)

            # Forward headers
            for k, v in response.getheaders():
                logging.info('    {0}: {1}'.format(k, v))
                adapter.set_header(k, v)

        elif request_type == 'elements':
            # Create request elements
            if json_input:
                result = self.request_elements(
                    json_input=json_input, return_json=True)
            else:
                result = self.request_elements(credentials=credentials,
                                               url=url,
                                               method=method,
                                               params=params,
                                               headers=headers,
                                               body=body,
                                               return_json=True)

            adapter.set_header('Content-Type', 'application/json')
        else:
            result = '{"error": "Bad Request!"}'

        # Add the authomatic header
        adapter.set_header(AUTHOMATIC_HEADER, request_type)

        # Write result to response
        adapter.write(result)
@@ -1,1272 +1,1272 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Set of diffing helpers, previously part of vcs
24 24 """
25 25
26 26 import os
27 27 import re
28 28 import bz2
29 29 import gzip
30 30 import time
31 31
32 32 import collections
33 33 import difflib
34 34 import logging
35 import cPickle as pickle
36 from itertools import tee, imap
35 import pickle
36 from itertools import tee
37 37
38 38 from rhodecode.lib.vcs.exceptions import VCSError
39 39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41 41
# module-level logger
log = logging.getLogger(__name__)

# define max context, a file with more than this number of lines is unusable
# in browser anyway
MAX_CONTEXT = 20 * 1024
# number of context lines used when the full context is not requested
DEFAULT_CONTEXT = 3
48 48
49 49
def get_diff_context(request):
    """
    Returns the number of diff context lines selected by the request:
    ``MAX_CONTEXT`` when the ``fullcontext`` GET parameter equals ``'1'``,
    ``DEFAULT_CONTEXT`` otherwise.
    """
    wants_full_context = request.GET.get('fullcontext', '') == '1'
    if wants_full_context:
        return MAX_CONTEXT
    return DEFAULT_CONTEXT
52 52
53 53
def get_diff_whitespace_flag(request):
    """
    Returns ``True`` when the ``ignorews`` GET parameter of the request
    equals ``'1'``, ``False`` otherwise.
    """
    flag = request.GET.get('ignorews', '')
    return flag == '1'
56 56
57 57
class OPS(object):
    """
    One-letter codes for the operation a diff chunk performs on a file;
    the git diff parser stores one of them in its ``op`` variable.
    """
    ADD = 'A'  # new file
    MOD = 'M'  # modified file
    DEL = 'D'  # deleted file
62 62
63 63
def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
    """
    Returns git style diff between given ``filenode_old`` and ``filenode_new``.

    :param filenode_old: node providing the old path and commit; when its
        commit is empty the repository's ``EMPTY_COMMIT`` is used
    :param filenode_new: node providing the new path, commit and repository
    :param ignore_whitespace: ignore whitespaces in diff
    :param context: number of context lines; a falsy value falls back
        to 3 and values above ``MAX_CONTEXT`` are clamped to it
    :returns: an empty string if either node is a ``SubModuleNode``,
        otherwise the result of ``repo.get_diff``
    :raises VCSError: if either node is not a ``FileNode``
    """
    # make sure we pass in default context
    context = context or 3
    # protect against IntOverflow when passing HUGE context
    if context > MAX_CONTEXT:
        context = MAX_CONTEXT

    # ``filter()`` returns a lazy iterator on Python 3, which is always
    # truthy, so use ``any()`` to actually look at the nodes.
    has_submodule = any(
        isinstance(node, SubModuleNode)
        for node in (filenode_new, filenode_old))
    if has_submodule:
        return ''

    for filenode in (filenode_old, filenode_new):
        if not isinstance(filenode, FileNode):
            raise VCSError(
                "Given object should be FileNode object, not %s"
                % filenode.__class__)

    repo = filenode_new.commit.repository
    old_commit = filenode_old.commit or repo.EMPTY_COMMIT
    new_commit = filenode_new.commit

    vcs_gitdiff = repo.get_diff(
        old_commit, new_commit, filenode_new.path,
        ignore_whitespace, context, path1=filenode_old.path)
    return vcs_gitdiff
95 95
# Numeric identifiers of the kinds of change a diff can record for a file;
# the git diff parser uses them as keys of its ``stats['ops']`` dict.
NEW_FILENODE = 1
DEL_FILENODE = 2
MOD_FILENODE = 3
RENAMED_FILENODE = 4
COPIED_FILENODE = 5
CHMOD_FILENODE = 6
BIN_FILENODE = 7
103 103
104 104
class LimitedDiffContainer(object):
    """
    Wraps a diff together with the limit and the size that were recorded
    for it; indexing and iteration are delegated to the wrapped diff.
    """

    def __init__(self, diff_limit, cur_diff_size, diff):
        self.diff_limit = diff_limit
        self.cur_diff_size = cur_diff_size
        self.diff = diff

    def __getitem__(self, key):
        wrapped = self.diff
        return wrapped.__getitem__(key)

    def __iter__(self):
        for entry in self.diff:
            yield entry
118 118
119 119
class Action(object):
    """
    Contains constants for the action value of the lines in a parsed diff.
    """

    # ADD and DELETE are the values the inline highlighters compare a
    # line's ``'action'`` entry against.
    ADD = 'add'
    DELETE = 'del'
    UNMODIFIED = 'unmod'

    # NOTE(review): presumably markers for hunk-context lines and for the
    # "no newline at end of file" notes of the old/new side - confirm
    # against the parser.
    CONTEXT = 'context'
    OLD_NO_NL = 'old-no-nl'
    NEW_NO_NL = 'new-no-nl'
132 132
133 133
134 134 class DiffProcessor(object):
135 135 """
136 136 Give it a unified or git diff and it returns a list of the files that were
137 137 mentioned in the diff together with a dict of meta information that
138 138 can be used to render it in a HTML template.
139 139
140 140 .. note:: Unicode handling
141 141
142 142 The original diffs are a byte sequence and can contain filenames
143 143 in mixed encodings. This class generally returns `unicode` objects
144 144 since the result is intended for presentation to the user.
145 145
146 146 """
147 147 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 148 _newline_marker = re.compile(r'^\\ No newline at end of file')
149 149
150 150 # used for inline highlighter word split
151 151 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152 152
153 153 # collapse ranges of commits over given number
154 154 _collapse_commits_over = 5
155 155
def __init__(self, diff, format='gitdiff', diff_limit=None,
             file_limit=None, show_full_diff=True):
    """
    :param diff: A `Diff` object representing a diff from a vcs backend
    :param format: format of diff passed, `udiff` or `gitdiff`
    :param diff_limit: define the size of diff that is considered "big"
        based on that parameter cut off will be triggered, set to None
        to show full diff
    :param file_limit: stored on the instance as ``file_limit``
    :param show_full_diff: stored on the instance; when true the escaper
        does not enforce ``diff_limit``
    """
    self._diff = diff
    self._format = format

    # counters of added / removed lines
    self.adds = 0
    self.removes = 0

    # calculate diff size
    self.diff_limit = diff_limit
    self.file_limit = file_limit
    self.show_full_diff = show_full_diff
    self.cur_diff_size = 0

    # parsing state
    self.parsed = False
    self.parsed_diff = []

    log.debug('Initialized DiffProcessor with %s mode', format)

    # pick the inline highlighter and the parser matching the format
    self.differ, self._parser = (
        (self._highlight_line_difflib, self._parse_gitdiff)
        if format == 'gitdiff' else
        (self._highlight_line_udiff, self._new_parse_gitdiff)
    )
184 184
def _copy_iterator(self):
    """
    Make a fresh copy of the generator. We should not iterate through
    the original as it's needed for repeating operations on
    this instance of DiffProcessor.

    :returns: an independent iterator over the stored diff
    """
    # tee() splits the stored iterator in two: one half is stored back for
    # later calls, the other half is handed to the caller.
    # NOTE(review): ``self.__udiff`` is name-mangled inside the class and is
    # not assigned anywhere in the visible part of the file - confirm it is
    # set before this method is called.
    self.__udiff, iterator_copy = tee(self.__udiff)
    return iterator_copy
193 193
194 194 def _escaper(self, string):
195 195 """
196 196 Escaper for diff escapes special chars and checks the diff limit
197 197
198 198 :param string:
199 199 """
200 200 self.cur_diff_size += len(string)
201 201
202 202 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 203 raise DiffLimitExceeded('Diff Limit Exceeded')
204 204
205 205 return string \
206 206 .replace('&', '&amp;')\
207 207 .replace('<', '&lt;')\
208 208 .replace('>', '&gt;')
209 209
def _line_counter(self, l):
    """
    Checks each line and bumps total adds/removes for this diff

    Lines starting with ``+`` count as additions and lines starting with
    ``-`` as removals, except for the ``+++`` / ``---`` file headers.

    :param l: a single line of the diff
    :returns: the line passed through ``safe_unicode``
    """
    if l.startswith('+'):
        if not l.startswith('+++'):
            self.adds += 1
    elif l.startswith('-'):
        if not l.startswith('---'):
            self.removes += 1
    return safe_unicode(l)
221 221
def _highlight_line_difflib(self, line, next_):
    """
    Highlight inline changes in both lines.

    Both lines are split into tokens with ``self._token_re`` and compared
    with :class:`difflib.SequenceMatcher`; differing fragments are wrapped
    in ``<del>`` (old line) or ``<ins>`` (new line). The ``'line'`` entries
    of both dicts are rewritten in place.

    :param line: parsed diff line (dict with ``'action'`` and ``'line'``)
    :param next_: the line paired with ``line``
    """

    # whichever of the two is the deletion is the "old" side
    removed, added = (
        (line, next_) if line['action'] == Action.DELETE else (next_, line))

    old_tokens = self._token_re.split(removed['line'])
    new_tokens = self._token_re.split(added['line'])
    matcher = difflib.SequenceMatcher(None, old_tokens, new_tokens)

    old_parts = []
    new_parts = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        old_part = ''.join(old_tokens[i1:i2])
        new_part = ''.join(new_tokens[j1:j2])
        changed = tag != 'equal'
        if changed and old_part:
            old_part = '<del>%s</del>' % old_part
        if changed and new_part:
            new_part = '<ins>%s</ins>' % new_part
        old_parts.append(old_part)
        new_parts.append(new_part)

    removed['line'] = "".join(old_parts)
    added['line'] = "".join(new_parts)
250 250
251 251 def _highlight_line_udiff(self, line, next_):
252 252 """
253 253 Highlight inline changes in both lines.
254 254 """
255 255 start = 0
256 256 limit = min(len(line['line']), len(next_['line']))
257 257 while start < limit and line['line'][start] == next_['line'][start]:
258 258 start += 1
259 259 end = -1
260 260 limit -= start
261 261 while -end <= limit and line['line'][end] == next_['line'][end]:
262 262 end -= 1
263 263 end += 1
264 264 if start or end:
265 265 def do(l):
266 266 last = end + len(l['line'])
267 267 if l['action'] == Action.ADD:
268 268 tag = 'ins'
269 269 else:
270 270 tag = 'del'
271 271 l['line'] = '%s<%s>%s</%s>%s' % (
272 272 l['line'][:start],
273 273 tag,
274 274 l['line'][start:last],
275 275 tag,
276 276 l['line'][last:]
277 277 )
278 278 do(line)
279 279 do(next_)
280 280
281 281 def _clean_line(self, line, command):
282 282 if command in ['+', '-', ' ']:
283 283 # only modify the line if it's actually a diff thing
284 284 line = line[1:]
285 285 return line
286 286
def _parse_gitdiff(self, inline_diff=True):
    """
    Parse every chunk of ``self._diff`` into a per-file dict.

    For each file the operation (add / modify / delete) is derived from the
    parsed git header, the lines are parsed with ``_parse_lines`` and a
    leading pseudo-chunk holding the operation messages is prepended.

    :param inline_diff: when true, pairs of neighbouring changed lines are
        additionally passed to ``self.differ`` for inline highlighting.
    :returns: the file dicts sorted as added, modified, deleted. The list
        is wrapped into ``LimitedDiffContainer`` when any file hit a limit.
    """
    _files = []
    diff_container = lambda arg: arg

    for chunk in self._diff.chunks():
        head = chunk.header

        # Lazy on purpose (works on py2 and py3 alike): _escaper may raise
        # DiffLimitExceeded, and that has to happen while _parse_lines
        # consumes the lines inside the try block below.
        diff = (self._escaper(diff_line)
                for diff_line in self.diff_splitter(chunk.diff))
        raw_diff = chunk.raw
        limited_diff = False
        exceeds_limit = False

        op = None
        stats = {
            'added': 0,
            'deleted': 0,
            'binary': False,
            'ops': {},
        }

        if head['deleted_file_mode']:
            op = OPS.DEL
            stats['binary'] = True
            stats['ops'][DEL_FILENODE] = 'deleted file'

        elif head['new_file_mode']:
            op = OPS.ADD
            stats['binary'] = True
            stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
        else:  # modify operation, can be copy, rename or chmod

            # CHMOD
            if head['new_mode'] and head['old_mode']:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][CHMOD_FILENODE] = (
                    'modified file chmod %s => %s' % (
                        head['old_mode'], head['new_mode']))
            # RENAME
            if head['rename_from'] != head['rename_to']:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][RENAMED_FILENODE] = (
                    'file renamed from %s to %s' % (
                        head['rename_from'], head['rename_to']))
            # COPY
            if head.get('copy_from') and head.get('copy_to'):
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][COPIED_FILENODE] = (
                    'file copied from %s to %s' % (
                        head['copy_from'], head['copy_to']))

        # If our new parsed headers didn't match anything fallback to
        # old style detection
        if op is None:
            if not head['a_file'] and head['b_file']:
                op = OPS.ADD
                stats['binary'] = True
                stats['ops'][NEW_FILENODE] = 'new file'

            elif head['a_file'] and not head['b_file']:
                op = OPS.DEL
                stats['binary'] = True
                stats['ops'][DEL_FILENODE] = 'deleted file'

        # it's not ADD not DELETE
        if op is None:
            op = OPS.MOD
            stats['binary'] = True
            stats['ops'][MOD_FILENODE] = 'modified file'

        # a real non-binary diff
        if head['a_file'] or head['b_file']:
            try:
                raw_diff, chunks, _stats = self._parse_lines(diff)
                stats['binary'] = False
                stats['added'] = _stats[0]
                stats['deleted'] = _stats[1]
                # explicit mark that it's a modified file
                if op == OPS.MOD:
                    stats['ops'][MOD_FILENODE] = 'modified file'
                exceeds_limit = len(raw_diff) > self.file_limit

                # changed from _escaper function so we validate size of
                # each file instead of the whole diff
                # diff will hide big files but still show small ones
                # from my tests, big files are fairly safe to be parsed
                # but the browser is the bottleneck
                if not self.show_full_diff and exceeds_limit:
                    raise DiffLimitExceeded('File Limit Exceeded')

            except DiffLimitExceeded:
                diff_container = lambda _diff: \
                    LimitedDiffContainer(
                        self.diff_limit, self.cur_diff_size, _diff)

                exceeds_limit = len(raw_diff) > self.file_limit
                limited_diff = True
                chunks = []

        else:  # GIT format binary patch, or possibly empty diff
            if head['bin_patch']:
                # we have operation already extracted, but we mark simply
                # it's a diff we wont show for binary files
                stats['ops'][BIN_FILENODE] = 'binary diff hidden'
            chunks = []

        if chunks and not self.show_full_diff and op == OPS.DEL:
            # outside of full diff mode the content of a deleted file
            # is not shown
            # TODO: anderson: if the view is not too big, there is no way
            # to see the content of the file
            chunks = []

        # dict.items() exists on both py2 and py3, iteritems() only on py2
        chunks.insert(0, [{
            'old_lineno': '',
            'new_lineno': '',
            'action': Action.CONTEXT,
            'line': msg,
        } for _op, msg in stats['ops'].items()
            if _op not in [MOD_FILENODE]])

        _files.append({
            'filename': safe_unicode(head['b_path']),
            'old_revision': head['a_blob_id'],
            'new_revision': head['b_blob_id'],
            'chunks': chunks,
            'raw_diff': safe_unicode(raw_diff),
            'operation': op,
            'stats': stats,
            'exceeds_limit': exceeds_limit,
            'is_limited_diff': limited_diff,
        })

    sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
                           OPS.DEL: 2}.get(info['operation'])

    if not inline_diff:
        return diff_container(sorted(_files, key=sorter))

    # highlight inline changes
    for diff_data in _files:
        for chunk in diff_data['chunks']:
            lineiter = iter(chunk)
            try:
                while 1:
                    line = next(lineiter)
                    if line['action'] not in (
                            Action.UNMODIFIED, Action.CONTEXT):
                        nextline = next(lineiter)
                        if nextline['action'] in ['unmod', 'context'] or \
                                nextline['action'] == line['action']:
                            continue
                        self.differ(line, nextline)
            except StopIteration:
                # chunk exhausted
                pass

    return diff_container(sorted(_files, key=sorter))
445 445
446 446 def _check_large_diff(self):
447 447 if self.diff_limit:
448 448 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
449 449 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
450 450 raise DiffLimitExceeded('Diff Limit `%s` Exceeded', self.diff_limit)
451 451
452 452 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
def _new_parse_gitdiff(self, inline_diff=True):
    """
    Parse every chunk of ``self._diff`` into a per-file dict (new format).

    The operation and mode/rename/copy details are derived from the parsed
    git header, and hunks are parsed with ``_new_parse_lines``. Files over
    ``self.file_limit``, or diffs over ``self.diff_limit``, end up with
    empty chunks and ``is_limited_diff`` set.

    :param inline_diff: not used by this implementation.
    :returns: the file dicts sorted as added, modified, deleted. The list
        is wrapped into ``LimitedDiffContainer`` when any limit was hit.
    """
    _files = []

    # this can be overridden later to a LimitedDiffContainer type
    diff_container = lambda arg: arg

    for chunk in self._diff.chunks():
        head = chunk.header
        log.debug('parsing diff %r', head)

        raw_diff = chunk.raw
        limited_diff = False
        exceeds_limit = False

        op = None
        stats = {
            'added': 0,
            'deleted': 0,
            'binary': False,
            'old_mode': None,
            'new_mode': None,
            'ops': {},
        }
        if head['old_mode']:
            stats['old_mode'] = head['old_mode']
        if head['new_mode']:
            stats['new_mode'] = head['new_mode']
        if head['b_mode']:
            stats['new_mode'] = head['b_mode']

        # delete file
        if head['deleted_file_mode']:
            op = OPS.DEL
            stats['binary'] = True
            stats['ops'][DEL_FILENODE] = 'deleted file'

        # new file
        elif head['new_file_mode']:
            op = OPS.ADD
            stats['binary'] = True
            stats['old_mode'] = None
            stats['new_mode'] = head['new_file_mode']
            stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']

        # modify operation, can be copy, rename or chmod
        else:
            # CHMOD
            if head['new_mode'] and head['old_mode']:
                op = OPS.MOD
                stats['binary'] = True
                stats['ops'][CHMOD_FILENODE] = (
                    'modified file chmod %s => %s' % (
                        head['old_mode'], head['new_mode']))

            # RENAME
            if head['rename_from'] != head['rename_to']:
                op = OPS.MOD
                stats['binary'] = True
                stats['renamed'] = (head['rename_from'], head['rename_to'])
                stats['ops'][RENAMED_FILENODE] = (
                    'file renamed from %s to %s' % (
                        head['rename_from'], head['rename_to']))
            # COPY
            if head.get('copy_from') and head.get('copy_to'):
                op = OPS.MOD
                stats['binary'] = True
                stats['copied'] = (head['copy_from'], head['copy_to'])
                stats['ops'][COPIED_FILENODE] = (
                    'file copied from %s to %s' % (
                        head['copy_from'], head['copy_to']))

        # If our new parsed headers didn't match anything fallback to
        # old style detection
        if op is None:
            if not head['a_file'] and head['b_file']:
                op = OPS.ADD
                stats['binary'] = True
                stats['new_file'] = True
                stats['ops'][NEW_FILENODE] = 'new file'

            elif head['a_file'] and not head['b_file']:
                op = OPS.DEL
                stats['binary'] = True
                stats['ops'][DEL_FILENODE] = 'deleted file'

        # it's not ADD not DELETE
        if op is None:
            op = OPS.MOD
            stats['binary'] = True
            stats['ops'][MOD_FILENODE] = 'modified file'

        # a real non-binary diff
        if head['a_file'] or head['b_file']:
            # simulate splitlines, so we keep the line end part
            diff = self.diff_splitter(chunk.diff)

            # append each file to the diff size
            raw_chunk_size = len(raw_diff)

            exceeds_limit = raw_chunk_size > self.file_limit
            self.cur_diff_size += raw_chunk_size

            try:
                # Check each file instead of the whole diff.
                # Diff will hide big files but still show small ones.
                # From the tests big files are fairly safe to be parsed
                # but the browser is the bottleneck.
                if not self.show_full_diff and exceeds_limit:
                    log.debug('File `%s` exceeds current file_limit of %s',
                              safe_unicode(head['b_path']), self.file_limit)
                    # format explicitly, exceptions do not interpolate args
                    raise DiffLimitExceeded(
                        'File Limit %s Exceeded' % (self.file_limit,))

                self._check_large_diff()

                raw_diff, chunks, _stats = self._new_parse_lines(diff)
                stats['binary'] = False
                stats['added'] = _stats[0]
                stats['deleted'] = _stats[1]
                # explicit mark that it's a modified file
                if op == OPS.MOD:
                    stats['ops'][MOD_FILENODE] = 'modified file'

            except DiffLimitExceeded:
                diff_container = lambda _diff: \
                    LimitedDiffContainer(
                        self.diff_limit, self.cur_diff_size, _diff)

                limited_diff = True
                chunks = []

        else:  # GIT format binary patch, or possibly empty diff
            if head['bin_patch']:
                # we have operation already extracted, but we mark simply
                # it's a diff we wont show for binary files
                stats['ops'][BIN_FILENODE] = 'binary diff hidden'
            chunks = []

        # Hide content of deleted node by setting empty chunks
        if chunks and not self.show_full_diff and op == OPS.DEL:
            # TODO: anderson: if the view is not too big, there is no way
            # to see the content of the file
            chunks = []

        # dict.items() exists on both py2 and py3, iteritems() only on py2
        chunks.insert(
            0, [{'old_lineno': '',
                 'new_lineno': '',
                 'action': Action.CONTEXT,
                 'line': msg,
                 } for _op, msg in stats['ops'].items()
                if _op not in [MOD_FILENODE]])

        original_filename = safe_unicode(head['a_path'])
        _files.append({
            'original_filename': original_filename,
            'filename': safe_unicode(head['b_path']),
            'old_revision': head['a_blob_id'],
            'new_revision': head['b_blob_id'],
            'chunks': chunks,
            'raw_diff': safe_unicode(raw_diff),
            'operation': op,
            'stats': stats,
            'exceeds_limit': exceeds_limit,
            'is_limited_diff': limited_diff,
        })

    sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
                           OPS.DEL: 2}.get(info['operation'])

    return diff_container(sorted(_files, key=sorter))
624 624
625 625 # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
626 626 def _parse_lines(self, diff_iter):
627 627 """
628 628 Parse the diff an return data for the template.
629 629 """
630 630
631 631 stats = [0, 0]
632 632 chunks = []
633 633 raw_diff = []
634 634
635 635 try:
636 line = diff_iter.next()
636 line = next(diff_iter)
637 637
638 638 while line:
639 639 raw_diff.append(line)
640 640 lines = []
641 641 chunks.append(lines)
642 642
643 643 match = self._chunk_re.match(line)
644 644
645 645 if not match:
646 646 break
647 647
648 648 gr = match.groups()
649 649 (old_line, old_end,
650 650 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
651 651 old_line -= 1
652 652 new_line -= 1
653 653
654 654 context = len(gr) == 5
655 655 old_end += old_line
656 656 new_end += new_line
657 657
658 658 if context:
659 659 # skip context only if it's first line
660 660 if int(gr[0]) > 1:
661 661 lines.append({
662 662 'old_lineno': '...',
663 663 'new_lineno': '...',
664 664 'action': Action.CONTEXT,
665 665 'line': line,
666 666 })
667 667
668 line = diff_iter.next()
668 line = next(diff_iter)
669 669
670 670 while old_line < old_end or new_line < new_end:
671 671 command = ' '
672 672 if line:
673 673 command = line[0]
674 674
675 675 affects_old = affects_new = False
676 676
677 677 # ignore those if we don't expect them
678 678 if command in '#@':
679 679 continue
680 680 elif command == '+':
681 681 affects_new = True
682 682 action = Action.ADD
683 683 stats[0] += 1
684 684 elif command == '-':
685 685 affects_old = True
686 686 action = Action.DELETE
687 687 stats[1] += 1
688 688 else:
689 689 affects_old = affects_new = True
690 690 action = Action.UNMODIFIED
691 691
692 692 if not self._newline_marker.match(line):
693 693 old_line += affects_old
694 694 new_line += affects_new
695 695 lines.append({
696 696 'old_lineno': affects_old and old_line or '',
697 697 'new_lineno': affects_new and new_line or '',
698 698 'action': action,
699 699 'line': self._clean_line(line, command)
700 700 })
701 701 raw_diff.append(line)
702 702
703 line = diff_iter.next()
703 line = next(diff_iter)
704 704
705 705 if self._newline_marker.match(line):
706 706 # we need to append to lines, since this is not
707 707 # counted in the line specs of diff
708 708 lines.append({
709 709 'old_lineno': '...',
710 710 'new_lineno': '...',
711 711 'action': Action.CONTEXT,
712 712 'line': self._clean_line(line, command)
713 713 })
714 714
715 715 except StopIteration:
716 716 pass
717 717 return ''.join(raw_diff), chunks, stats
718 718
719 719 # FIXME: NEWDIFFS: dan: this replaces _parse_lines
720 720 def _new_parse_lines(self, diff_iter):
721 721 """
722 722 Parse the diff an return data for the template.
723 723 """
724 724
725 725 stats = [0, 0]
726 726 chunks = []
727 727 raw_diff = []
728 728
729 729 try:
730 line = diff_iter.next()
730 line = next(diff_iter)
731 731
732 732 while line:
733 733 raw_diff.append(line)
734 734 # match header e.g @@ -0,0 +1 @@\n'
735 735 match = self._chunk_re.match(line)
736 736
737 737 if not match:
738 738 break
739 739
740 740 gr = match.groups()
741 741 (old_line, old_end,
742 742 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
743 743
744 744 lines = []
745 745 hunk = {
746 746 'section_header': gr[-1],
747 747 'source_start': old_line,
748 748 'source_length': old_end,
749 749 'target_start': new_line,
750 750 'target_length': new_end,
751 751 'lines': lines,
752 752 }
753 753 chunks.append(hunk)
754 754
755 755 old_line -= 1
756 756 new_line -= 1
757 757
758 758 context = len(gr) == 5
759 759 old_end += old_line
760 760 new_end += new_line
761 761
762 line = diff_iter.next()
762 line = next(diff_iter)
763 763
764 764 while old_line < old_end or new_line < new_end:
765 765 command = ' '
766 766 if line:
767 767 command = line[0]
768 768
769 769 affects_old = affects_new = False
770 770
771 771 # ignore those if we don't expect them
772 772 if command in '#@':
773 773 continue
774 774 elif command == '+':
775 775 affects_new = True
776 776 action = Action.ADD
777 777 stats[0] += 1
778 778 elif command == '-':
779 779 affects_old = True
780 780 action = Action.DELETE
781 781 stats[1] += 1
782 782 else:
783 783 affects_old = affects_new = True
784 784 action = Action.UNMODIFIED
785 785
786 786 if not self._newline_marker.match(line):
787 787 old_line += affects_old
788 788 new_line += affects_new
789 789 lines.append({
790 790 'old_lineno': affects_old and old_line or '',
791 791 'new_lineno': affects_new and new_line or '',
792 792 'action': action,
793 793 'line': self._clean_line(line, command)
794 794 })
795 795 raw_diff.append(line)
796 796
797 line = diff_iter.next()
797 line = next(diff_iter)
798 798
799 799 if self._newline_marker.match(line):
800 800 # we need to append to lines, since this is not
801 801 # counted in the line specs of diff
802 802 if affects_old:
803 803 action = Action.OLD_NO_NL
804 804 elif affects_new:
805 805 action = Action.NEW_NO_NL
806 806 else:
807 807 raise Exception('invalid context for no newline')
808 808
809 809 lines.append({
810 810 'old_lineno': None,
811 811 'new_lineno': None,
812 812 'action': action,
813 813 'line': self._clean_line(line, command)
814 814 })
815 815
816 816 except StopIteration:
817 817 pass
818 818
819 819 return ''.join(raw_diff), chunks, stats
820 820
821 821 def _safe_id(self, idstring):
822 822 """Make a string safe for including in an id attribute.
823 823
824 824 The HTML spec says that id attributes 'must begin with
825 825 a letter ([A-Za-z]) and may be followed by any number
826 826 of letters, digits ([0-9]), hyphens ("-"), underscores
827 827 ("_"), colons (":"), and periods (".")'. These regexps
828 828 are slightly over-zealous, in that they remove colons
829 829 and periods unnecessarily.
830 830
831 831 Whitespace is transformed into underscores, and then
832 832 anything which is not a hyphen or a character that
833 833 matches \w (alphanumerics and underscore) is removed.
834 834
835 835 """
836 836 # Transform all whitespace to underscore
837 837 idstring = re.sub(r'\s', "_", '%s' % idstring)
838 838 # Remove everything that is not a hyphen or a member of \w
839 839 idstring = re.sub(r'(?!-)\W', "", idstring).lower()
840 840 return idstring
841 841
@classmethod
def diff_splitter(cls, string):
    r"""
    Yield the lines of ``string``, splitting on ``\n`` only.

    Works like ``.splitlines(True)`` restricted to ``\n``: every yielded
    line keeps its trailing newline, and the last one has none when
    ``string`` does not end with one. Nothing is yielded for an empty
    value.
    """
    if not string:
        return

    if string == '\n':
        yield u'\n'
        return

    ends_with_newline = string.endswith('\n')
    parts = string.split('\n')
    if ends_with_newline:
        # the split leaves an empty string after the final newline
        parts = parts[:-1]

    last_idx = len(parts) - 1
    for idx, part in enumerate(parts):
        if idx == last_idx and not ends_with_newline:
            yield safe_unicode(part)
        else:
            yield safe_unicode(part) + '\n'
867 867
def prepare(self, inline_diff=True):
    """
    Run the configured parser over the diff and remember the outcome.

    The result is stored in ``self.parsed_diff`` and ``self.parsed`` is
    switched on.

    :param inline_diff: forwarded to ``self._parser``
    :return: whatever ``self._parser`` returned
    """
    result = self._parser(inline_diff=inline_diff)
    self.parsed, self.parsed_diff = True, result
    return result
878 878
def as_raw(self, diff_lines=None):
    """
    Return the ``raw`` attribute of the wrapped diff object.

    :param diff_lines: accepted but not used
    """
    raw = self._diff.raw
    return raw
884 884
def as_html(self, table_class='code-difftable', line_class='line',
            old_lineno_class='lineno old', new_lineno_class='lineno new',
            code_class='code', enable_comments=False, parsed_lines=None):
    """
    Return given diff as html table with customized css classes

    :param table_class: css class of the ``<table>`` element
    :param line_class: css class put on every ``<tr>``, next to the action
    :param old_lineno_class: css class of the old line number cell
    :param new_lineno_class: css class of the new line number cell
    :param code_class: css class of the code cell
    :param enable_comments: when true, non-context lines get an
        "add comment" icon and their code cell is not marked ``no-comment``
    :param parsed_lines: already parsed files to render instead of
        ``self.parsed_diff``
    :returns: the html as one string, or ``None`` when none of the
        rendered files had any chunk
    """
    # TODO(marcink): not sure how to pass in translator
    # here in an efficient way, leave the _ for proper gettext extraction
    _ = lambda s: s

    def _link_to_if(condition, label, url):
        """
        Generates a link if condition is met or just the label if not.
        """

        if condition:
            # NOTE(review): the literal below spans two lines; whitespace
            # at the start of its second line is part of the emitted html
            return '''<a href="%(url)s" class="tooltip"
title="%(title)s">%(label)s</a>''' % {
                'title': _('Click to select line'),
                'url': url,
                'label': label
            }
        else:
            return label
    # parse lazily on first use
    if not self.parsed:
        self.prepare()

    diff_lines = self.parsed_diff
    if parsed_lines:
        diff_lines = parsed_lines

    # stays True until at least one chunk has been seen
    _html_empty = True
    _html = []
    _html.append('''<table class="%(table_class)s">\n''' % {
        'table_class': table_class
    })

    for diff in diff_lines:
        # each chunk is iterated as a sequence of line dicts
        for line in diff['chunks']:
            _html_empty = False
            for change in line:
                _html.append('''<tr class="%(lc)s %(action)s">\n''' % {
                    'lc': line_class,
                    'action': change['action']
                })
                anchor_old_id = ''
                anchor_new_id = ''
                anchor_old = "%(filename)s_o%(oldline_no)s" % {
                    'filename': self._safe_id(diff['filename']),
                    'oldline_no': change['old_lineno']
                }
                # the placeholder is called oldline_no here as well, but
                # it is filled with the new line number
                anchor_new = "%(filename)s_n%(oldline_no)s" % {
                    'filename': self._safe_id(diff['filename']),
                    'oldline_no': change['new_lineno']
                }
                # an id is only emitted for real line numbers, not for
                # the '...' placeholder or an empty value
                cond_old = (change['old_lineno'] != '...' and
                            change['old_lineno'])
                cond_new = (change['new_lineno'] != '...' and
                            change['new_lineno'])
                if cond_old:
                    anchor_old_id = 'id="%s"' % anchor_old
                if cond_new:
                    anchor_new_id = 'id="%s"' % anchor_new

                # context rows show plain numbers instead of links
                if change['action'] != Action.CONTEXT:
                    anchor_link = True
                else:
                    anchor_link = False

                ###########################################################
                # COMMENT ICONS
                ###########################################################
                _html.append('''\t<td class="add-comment-line"><span class="add-comment-content">''')

                if enable_comments and change['action'] != Action.CONTEXT:
                    _html.append('''<a href="#"><span class="icon-comment-add"></span></a>''')

                _html.append('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')

                ###########################################################
                # OLD LINE NUMBER
                ###########################################################
                _html.append('''\t<td %(a_id)s class="%(olc)s">''' % {
                    'a_id': anchor_old_id,
                    'olc': old_lineno_class
                })

                _html.append('''%(link)s''' % {
                    'link': _link_to_if(anchor_link, change['old_lineno'],
                                        '#%s' % anchor_old)
                })
                _html.append('''</td>\n''')
                ###########################################################
                # NEW LINE NUMBER
                ###########################################################

                _html.append('''\t<td %(a_id)s class="%(nlc)s">''' % {
                    'a_id': anchor_new_id,
                    'nlc': new_lineno_class
                })

                _html.append('''%(link)s''' % {
                    'link': _link_to_if(anchor_link, change['new_lineno'],
                                        '#%s' % anchor_new)
                })
                _html.append('''</td>\n''')
                ###########################################################
                # CODE
                ###########################################################
                code_classes = [code_class]
                if (not enable_comments or
                        change['action'] == Action.CONTEXT):
                    code_classes.append('no-comment')
                _html.append('\t<td class="%s">' % ' '.join(code_classes))
                # NOTE(review): the line text is inserted as is - presumably
                # it was escaped while parsing, confirm against the parser
                _html.append('''\n\t\t<pre>%(code)s</pre>\n''' % {
                    'code': change['line']
                })

                _html.append('''\t</td>''')
                _html.append('''\n</tr>\n''')
    _html.append('''</table>''')
    if _html_empty:
        return None
    return ''.join(_html)
1009 1009
def stat(self):
    """
    Return the ``(adds, removes)`` line counters of this instance.
    """
    added, removed = self.adds, self.removes
    return added, removed
1015 1015
def get_context_of_line(
        self, path, diff_line=None, context_before=3, context_after=3):
    """
    Collect the lines surrounding ``diff_line`` in the file at ``path``.

    Only content lines (unmodified, added, deleted) within the window of
    ``context_before`` / ``context_after`` entries are returned, each as
    an ``(action, line)`` tuple. The last one is normalized to end with
    exactly one newline.

    :type diff_line: :class:`DiffLineNumber`
    :raises ValueError: when neither side of ``diff_line`` is ``None``
    """
    assert self.parsed, "DiffProcessor is not initialized."

    if None not in diff_line:
        raise ValueError(
            "Cannot specify both line numbers: {}".format(diff_line))

    chunk, idx = self._find_chunk_line_index(
        self._get_file_diff(path), diff_line)

    window_start = max(idx - context_before, 0)
    window_end = idx + context_after + 1

    line_contents = []
    for entry in chunk[window_start:window_end]:
        if _is_diff_content(entry):
            line_contents.append(_context_line(entry))

    # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
    # Once they are fixed, we can drop this line here.
    if line_contents:
        last_action, last_text = line_contents[-1]
        line_contents[-1] = (last_action, last_text.rstrip('\n') + '\n')
    return line_contents
1045 1045
def find_context(self, path, context, offset=0):
    """
    Finds the given `context` inside of the diff.

    Use the parameter `offset` to specify which offset the target line has
    inside of the given `context`. This way the correct diff line will be
    returned.

    :param path: filename of the file diff to search in.
    :param context: sequence that is compared item by item against the
        ``(action, line)`` pairs of the chunk lines.
    :param offset: Shall be used to specify the offset of the main line
        within the given `context`.
    :returns: list of :class:`DiffLineNumber`, one per match.
    :raises ValueError: when `offset` is negative or not smaller than
        ``len(context)``.
    """
    if offset < 0 or offset >= len(context):
        raise ValueError(
            "Only positive values up to the length of the context "
            "minus one are allowed.")

    matches = []
    file_diff = self._get_file_diff(path)

    for chunk in file_diff['chunks']:
        context_iter = iter(context)
        for line_idx, line in enumerate(chunk):
            try:
                if _context_line(line) == next(context_iter):
                    continue
            except StopIteration:
                # the context was fully consumed: record the position
                # and start over with a fresh iterator
                matches.append((line_idx, chunk))
                context_iter = iter(context)

        # Increment position and trigger StopIteration
        # if we had a match at the end
        # NOTE(review): this trailing check is assumed to run once per
        # chunk, after the line loop - confirm against the original
        # indentation. It relies on `line_idx` left over from the loop.
        line_idx += 1
        try:
            next(context_iter)
        except StopIteration:
            matches.append((line_idx, chunk))

    # recorded indexes point just past a match, step back to the target
    effective_offset = len(context) - offset
    found_at_diff_lines = [
        _line_to_diff_line_number(chunk[idx - effective_offset])
        for idx, chunk in matches]

    return found_at_diff_lines
1089 1089
1090 1090 def _get_file_diff(self, path):
1091 1091 for file_diff in self.parsed_diff:
1092 1092 if file_diff['filename'] == path:
1093 1093 break
1094 1094 else:
1095 1095 raise FileNotInDiffException("File {} not in diff".format(path))
1096 1096 return file_diff
1097 1097
1098 1098 def _find_chunk_line_index(self, file_diff, diff_line):
1099 1099 for chunk in file_diff['chunks']:
1100 1100 for idx, line in enumerate(chunk):
1101 1101 if line['old_lineno'] == diff_line.old:
1102 1102 return chunk, idx
1103 1103 if line['new_lineno'] == diff_line.new:
1104 1104 return chunk, idx
1105 1105 raise LineNotInDiffException(
1106 1106 "The line {} is not part of the diff.".format(diff_line))
1107 1107
1108 1108
def _is_diff_content(line):
    """
    Tell whether ``line`` is an unmodified, added or deleted line.
    """
    action = line['action']
    return action in (Action.UNMODIFIED, Action.ADD, Action.DELETE)
1112 1112
1113 1113
1114 1114 def _context_line(line):
1115 1115 return (line['action'], line['line'])
1116 1116
1117 1117
1118 1118 DiffLineNumber = collections.namedtuple('DiffLineNumber', ['old', 'new'])
1119 1119
1120 1120
1121 1121 def _line_to_diff_line_number(line):
1122 1122 new_line_no = line['new_lineno'] or None
1123 1123 old_line_no = line['old_lineno'] or None
1124 1124 return DiffLineNumber(old=old_line_no, new=new_line_no)
1125 1125
1126 1126
class FileNotInDiffException(Exception):
    """
    Raised when the context for a missing file is requested.

    If you request the context for a line in a file which is not part of
    the given diff, then this exception is raised.
    """
1134 1134
1135 1135
class LineNotInDiffException(Exception):
    """
    Raised when the context for a missing line is requested.

    If you request the context for a line in a file and this line is not
    part of the given diff, then this exception is raised.
    """
1143 1143
1144 1144
class DiffLimitExceeded(Exception):
    # Raised by the diff parsing code when a size limit is exceeded and a
    # full diff was not requested.
    pass
1147 1147
1148 1148
# NOTE(marcink): if diffs.mako changes, this probably
# needs a bump to the next version
1151 1151 CURRENT_DIFF_VERSION = 'v5'
1152 1152
1153 1153
def _cleanup_cache_file(cached_diff_file):
    """
    Remove the given cache file, logging instead of raising on failure.

    :param cached_diff_file: path of the cache file to delete
    """
    # cleanup file to not store it "damaged"
    try:
        os.remove(cached_diff_file)
    except Exception:
        # best effort: the failure is logged with a traceback and swallowed
        log.exception('Failed to cleanup path %s', cached_diff_file)
1160 1160
1161 1161
1162 1162 def _get_compression_mode(cached_diff_file):
1163 1163 mode = 'bz2'
1164 1164 if 'mode:plain' in cached_diff_file:
1165 1165 mode = 'plain'
1166 1166 elif 'mode:gzip' in cached_diff_file:
1167 1167 mode = 'gzip'
1168 1168 return mode
1169 1169
1170 1170
def cache_diff(cached_diff_file, diff, commits):
    """
    Pickle ``diff`` and ``commits`` into ``cached_diff_file``.

    The compression (plain, gzip or bz2) is derived from the file name.
    Failures are logged and the possibly half-written file is removed;
    nothing is raised to the caller.

    :param cached_diff_file: path of the cache file to write
    :param diff: diff data to store
    :param commits: commit data to store
    """
    compression_mode = _get_compression_mode(cached_diff_file)

    struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': diff,
        'commits': commits
    }

    if compression_mode == 'plain':
        opener = open
    elif compression_mode == 'gzip':
        opener = gzip.GzipFile
    else:
        opener = bz2.BZ2File

    start = time.time()
    try:
        with opener(cached_diff_file, 'wb') as f:
            pickle.dump(struct, f)
    except Exception:
        # Logger.warn is a deprecated alias of Logger.warning
        log.warning('Failed to save cache', exc_info=True)
        _cleanup_cache_file(cached_diff_file)
    else:
        # only report a save that actually happened
        log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)
1196 1196
1197 1197
def load_cached_diff(cached_diff_file):
    """
    Load the structure stored by ``cache_diff`` from ``cached_diff_file``.

    :param cached_diff_file: path of the cache file to read
    :returns: dict with the keys ``version``, ``diff`` and ``commits``.
        An empty default (``diff`` and ``commits`` set to ``None``) is
        returned when the file is missing, unreadable, not a dict or was
        written for another ``CURRENT_DIFF_VERSION``. In the last case
        the file is removed as well.
    """
    compression_mode = _get_compression_mode(cached_diff_file)

    default_struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': None,
        'commits': None
    }

    if not os.path.isfile(cached_diff_file):
        log.debug('Reading diff cache file failed %s', cached_diff_file)
        return default_struct

    if compression_mode == 'plain':
        opener = open
    elif compression_mode == 'gzip':
        opener = gzip.GzipFile
    else:
        opener = bz2.BZ2File

    data = None

    start = time.time()
    try:
        # SECURITY NOTE: unpickling executes whatever the file describes,
        # so the cache storage must only be writable by this application
        with opener(cached_diff_file, 'rb') as f:
            data = pickle.load(f)
    except Exception:
        # Logger.warn is a deprecated alias of Logger.warning
        log.warning('Failed to read diff cache file', exc_info=True)

    # nothing loaded, or an old non-dict format
    if not data or not isinstance(data, dict):
        data = default_struct

    # check version
    if data.get('version') != CURRENT_DIFF_VERSION:
        # purge cache
        _cleanup_cache_file(cached_diff_file)
        return default_struct

    log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)

    return data
1244 1244
1245 1245
def generate_diff_cache_key(*args):
    """
    Helper to generate a cache key using arguments

    Every argument is converted with ``safe_str``, has each ``/`` replaced by
    ``_`` and the resulting pieces are joined with ``_``. An argument that
    ends up as an empty string is rendered as ``None``.

    :param args: values the key is built from.
    :returns: the cache key string.
    """
    def arg_mapper(input_param):
        input_param = safe_str(input_param)
        # we cannot allow '/' in arguments since it would allow
        # subdirectory usage; replace() returns a new string, so the result
        # has to be assigned back for the substitution to take effect
        input_param = input_param.replace('/', '_')
        return input_param or None  # prevent empty string arguments

    return '_'.join('{}'.format(arg_mapper(arg)) for arg in args)
1259 1259
1260 1260
def diff_cache_exist(cache_storage, *args):
    """
    Based on all generated arguments check and return a cache path

    :param cache_storage: directory holding the diff cache files.
    :param args: values handed to ``generate_diff_cache_key``; a
        ``mode:gzip`` marker is appended to them.
    :returns: ``cache_storage`` joined with the generated cache key.
    :raises ValueError: when the resulting path is not located inside
        ``cache_storage``.
    """
    args = list(args) + ['mode:gzip']
    cache_key = generate_diff_cache_key(*args)
    cache_file_path = os.path.join(cache_storage, cache_key)

    # prevent path traversal attacks using some param that have e.g '../../'
    # Both sides are normalised and compared component-wise: a plain string
    # prefix test would accept a sibling such as '/cache_evil' for '/cache'
    # and would reject every path when cache_storage is relative.
    storage_root = os.path.abspath(cache_storage)
    resolved_path = os.path.abspath(cache_file_path)
    if os.path.commonpath([storage_root, resolved_path]) != storage_root:
        raise ValueError('Final path must be within {}'.format(cache_storage))

    return cache_file_path
General Comments 0
You need to be logged in to leave comments. Login now