##// END OF EJS Templates
classes: fix class style problems found by b071cd58af50...
Thomas Arendsen Hein -
r14764:a7d58160 stable
parent child Browse files
Show More
@@ -1,1107 +1,1107 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 # $Id: manpage.py 6110 2009-08-31 14:40:33Z grubert $
2 # $Id: manpage.py 6110 2009-08-31 14:40:33Z grubert $
3 # Author: Engelbert Gruber <grubert@users.sourceforge.net>
3 # Author: Engelbert Gruber <grubert@users.sourceforge.net>
4 # Copyright: This module is put into the public domain.
4 # Copyright: This module is put into the public domain.
5
5
6 """
6 """
7 Simple man page writer for reStructuredText.
7 Simple man page writer for reStructuredText.
8
8
9 Man pages (short for "manual pages") contain system documentation on unix-like
9 Man pages (short for "manual pages") contain system documentation on unix-like
10 systems. The pages are grouped in numbered sections:
10 systems. The pages are grouped in numbered sections:
11
11
12 1 executable programs and shell commands
12 1 executable programs and shell commands
13 2 system calls
13 2 system calls
14 3 library functions
14 3 library functions
15 4 special files
15 4 special files
16 5 file formats
16 5 file formats
17 6 games
17 6 games
18 7 miscellaneous
18 7 miscellaneous
19 8 system administration
19 8 system administration
20
20
21 Man pages are written in *troff*, a text file formatting system.
21 Man pages are written in *troff*, a text file formatting system.
22
22
23 See http://www.tldp.org/HOWTO/Man-Page for a start.
23 See http://www.tldp.org/HOWTO/Man-Page for a start.
24
24
25 Man pages have no subsections, only parts.
25 Man pages have no subsections, only parts.
26 Standard parts
26 Standard parts
27
27
28 NAME ,
28 NAME ,
29 SYNOPSIS ,
29 SYNOPSIS ,
30 DESCRIPTION ,
30 DESCRIPTION ,
31 OPTIONS ,
31 OPTIONS ,
32 FILES ,
32 FILES ,
33 SEE ALSO ,
33 SEE ALSO ,
34 BUGS ,
34 BUGS ,
35
35
36 and
36 and
37
37
38 AUTHOR .
38 AUTHOR .
39
39
40 A unix-like system keeps an index of the DESCRIPTIONs, which is accessible
40 A unix-like system keeps an index of the DESCRIPTIONs, which is accessible
41 by the command whatis or apropos.
41 by the command whatis or apropos.
42
42
43 """
43 """
44
44
45 __docformat__ = 'reStructuredText'
45 __docformat__ = 'reStructuredText'
46
46
47 import re
47 import re
48
48
49 from docutils import nodes, writers, languages
49 from docutils import nodes, writers, languages
50 import roman
50 import roman
51 import inspect
51 import inspect
52
52
53 FIELD_LIST_INDENT = 7
53 FIELD_LIST_INDENT = 7
54 DEFINITION_LIST_INDENT = 7
54 DEFINITION_LIST_INDENT = 7
55 OPTION_LIST_INDENT = 7
55 OPTION_LIST_INDENT = 7
56 BLOCKQOUTE_INDENT = 3.5
56 BLOCKQOUTE_INDENT = 3.5
57
57
58 # Define two macros so man/roff can calculate the
58 # Define two macros so man/roff can calculate the
59 # indent/unindent margins by itself
59 # indent/unindent margins by itself
60 MACRO_DEF = (r""".
60 MACRO_DEF = (r""".
61 .nr rst2man-indent-level 0
61 .nr rst2man-indent-level 0
62 .
62 .
63 .de1 rstReportMargin
63 .de1 rstReportMargin
64 \\$1 \\n[an-margin]
64 \\$1 \\n[an-margin]
65 level \\n[rst2man-indent-level]
65 level \\n[rst2man-indent-level]
66 level margin: \\n[rst2man-indent\\n[rst2man-indent-level]]
66 level margin: \\n[rst2man-indent\\n[rst2man-indent-level]]
67 -
67 -
68 \\n[rst2man-indent0]
68 \\n[rst2man-indent0]
69 \\n[rst2man-indent1]
69 \\n[rst2man-indent1]
70 \\n[rst2man-indent2]
70 \\n[rst2man-indent2]
71 ..
71 ..
72 .de1 INDENT
72 .de1 INDENT
73 .\" .rstReportMargin pre:
73 .\" .rstReportMargin pre:
74 . RS \\$1
74 . RS \\$1
75 . nr rst2man-indent\\n[rst2man-indent-level] \\n[an-margin]
75 . nr rst2man-indent\\n[rst2man-indent-level] \\n[an-margin]
76 . nr rst2man-indent-level +1
76 . nr rst2man-indent-level +1
77 .\" .rstReportMargin post:
77 .\" .rstReportMargin post:
78 ..
78 ..
79 .de UNINDENT
79 .de UNINDENT
80 . RE
80 . RE
81 .\" indent \\n[an-margin]
81 .\" indent \\n[an-margin]
82 .\" old: \\n[rst2man-indent\\n[rst2man-indent-level]]
82 .\" old: \\n[rst2man-indent\\n[rst2man-indent-level]]
83 .nr rst2man-indent-level -1
83 .nr rst2man-indent-level -1
84 .\" new: \\n[rst2man-indent\\n[rst2man-indent-level]]
84 .\" new: \\n[rst2man-indent\\n[rst2man-indent-level]]
85 .in \\n[rst2man-indent\\n[rst2man-indent-level]]u
85 .in \\n[rst2man-indent\\n[rst2man-indent-level]]u
86 ..
86 ..
87 """)
87 """)
88
88
class Writer(writers.Writer):

    # BUG FIX: ('manpage') is just a parenthesized string, not a tuple.
    # docutils tests ``format in writer.supported``; with a string that is a
    # substring test, so e.g. 'man' would wrongly be reported as supported.
    supported = ('manpage',)
    """Formats this writer supports."""

    output = None
    """Final translated form of `document`."""

    def __init__(self):
        writers.Writer.__init__(self)
        self.translator_class = Translator

    def translate(self):
        """Walk `self.document` with a `Translator` and store the troff text."""
        visitor = self.translator_class(self.document)
        self.document.walkabout(visitor)
        self.output = visitor.astext()
105
105
106
106
class Table(object):
    """Accumulate rows of cells and render them as a troff ``tbl`` block."""

    def __init__(self):
        self._rows = []
        self._options = ['center']
        self._tab_char = '\t'
        self._coldefs = []

    def new_row(self):
        """Open a fresh, initially empty row."""
        self._rows.append([])

    def append_separator(self, separator):
        """Append the separator for table head."""
        self._rows.append([separator])

    def append_cell(self, cell_lines):
        """Add one cell (``cell_lines`` is an array of lines) to the current row."""
        # A leading '.sp' request would add a spurious blank line; skip it.
        first = 1 if (len(cell_lines) > 0 and cell_lines[0] == '.sp\n') else 0
        self._rows[-1].append(cell_lines[first:])
        # One 'l' (left-aligned) column definition per column seen so far.
        if len(self._coldefs) < len(self._rows[-1]):
            self._coldefs.append('l')

    def _minimize_cell(self, cell_lines):
        """Remove leading and trailing blank and ``.sp`` lines (in place)."""
        blanks = ('\n', '.sp\n')
        while cell_lines and cell_lines[0] in blanks:
            del cell_lines[0]
        while cell_lines and cell_lines[-1] in blanks:
            del cell_lines[-1]

    def as_list(self):
        """Render the collected rows as a list of troff/tbl source lines."""
        out = ['.TS\n']
        out.append(' '.join(self._options) + ';\n')
        out.append('|%s|.\n' % '|'.join(self._coldefs))
        for row in self._rows:
            # row = array of cells. cell = array of lines.
            out.append('_\n')   # horizontal rule above the row
            out.append('T{\n')  # open the first text block
            last = len(row) - 1
            for idx, cell in enumerate(row):
                self._minimize_cell(cell)
                out.extend(cell)
                if not out[-1].endswith('\n'):
                    out[-1] += '\n'
                # close this text block; open the next unless at row end
                if idx < last:
                    out.append('T}' + self._tab_char + 'T{\n')
                else:
                    out.append('T}\n')
        out.append('_\n')
        out.append('.TE\n')
        return out
153
153
154 class Translator(nodes.NodeVisitor):
154 class Translator(nodes.NodeVisitor):
155 """"""
155 """"""
156
156
157 words_and_spaces = re.compile(r'\S+| +|\n')
157 words_and_spaces = re.compile(r'\S+| +|\n')
158 document_start = """Man page generated from reStructeredText."""
158 document_start = """Man page generated from reStructeredText."""
159
159
    def __init__(self, document):
        """Initialize per-document translator state.

        Output is buffered in three lists (``head``, ``body``, ``foot``)
        that ``astext`` joins at the end.
        """
        nodes.NodeVisitor.__init__(self, document)
        self.settings = settings = document.settings
        lcode = settings.language_code
        # languages.get_language changed arity across docutils versions;
        # inspect the signature so both old and new releases work.
        arglen = len(inspect.getargspec(languages.get_language)[0])
        if arglen == 2:
            self.language = languages.get_language(lcode,
                                                   self.document.reporter)
        else:
            self.language = languages.get_language(lcode)
        self.head = []
        self.body = []
        self.foot = []
        self.section_level = 0
        self.context = []
        self.topic_class = ''
        self.colspecs = []
        self.compact_p = 1
        self.compact_simple = None
        # the list style "*" bullet or "#" numbered
        self._list_char = []
        # writing the header .TH and .SH NAME is postponed until after
        # docinfo.
        self._docinfo = {
                "title" : "", "title_upper": "",
                "subtitle" : "",
                "manual_section" : "", "manual_group" : "",
                "author" : [],
                "date" : "",
                "copyright" : "",
                "version" : "",
                    }
        self._docinfo_keys = []   # a list to keep the sequence as in source.
        self._docinfo_names = {}  # to get name from text not normalized.
        self._in_docinfo = None
        self._active_table = None
        self._in_literal = False
        self.header_written = 0
        self._line_block = 0
        self.authors = []
        self.section_level = 0
        self._indent = [0]
        # central definition of simple processing rules
        # what to output on : visit, depart
        # Do not use paragraph requests ``.PP`` because these set indentation.
        # use ``.sp``. Remove superfluous ``.sp`` in ``astext``.
        #
        # Fonts are put on a stack, the top one is used.
        # ``.ft P`` or ``\\fP`` pop from stack.
        # ``B`` bold, ``I`` italic, ``R`` roman should be available.
        # Hopefully ``C`` courier too.
        self.defs = {
                'indent' : ('.INDENT %.1f\n', '.UNINDENT\n'),
                'definition_list_item' : ('.TP', ''),
                'field_name' : ('.TP\n.B ', '\n'),
                'literal' : ('\\fB', '\\fP'),
                'literal_block' : ('.sp\n.nf\n.ft C\n', '\n.ft P\n.fi\n'),

                'option_list_item' : ('.TP\n', ''),

                'reference' : (r'\%', r'\:'),
                'emphasis': ('\\fI', '\\fP'),
                'strong' : ('\\fB', '\\fP'),
                'term' : ('\n.B ', '\n'),
                'title_reference' : ('\\fI', '\\fP'),

                'topic-title' : ('.SS ',),
                'sidebar-title' : ('.SS ',),

                'problematic' : ('\n.nf\n', '\n.fi\n'),
              }
        # NOTE don't specify the newline before a dot-command, but ensure
        # it is there.
233
233
234 def comment_begin(self, text):
234 def comment_begin(self, text):
235 """Return commented version of the passed text WITHOUT end of
235 """Return commented version of the passed text WITHOUT end of
236 line/comment."""
236 line/comment."""
237 prefix = '.\\" '
237 prefix = '.\\" '
238 out_text = ''.join(
238 out_text = ''.join(
239 [(prefix + in_line + '\n')
239 [(prefix + in_line + '\n')
240 for in_line in text.split('\n')])
240 for in_line in text.split('\n')])
241 return out_text
241 return out_text
242
242
243 def comment(self, text):
243 def comment(self, text):
244 """Return commented version of the passed text."""
244 """Return commented version of the passed text."""
245 return self.comment_begin(text)+'.\n'
245 return self.comment_begin(text)+'.\n'
246
246
    def ensure_eol(self):
        """Ensure the last line in body is terminated by new line."""
        # NOTE(review): raises IndexError if body is empty or its last
        # entry is '' — presumably never happens in practice; verify.
        if self.body[-1][-1] != '\n':
            self.body.append('\n')
251
251
    def astext(self):
        """Return the final formatted document as a string."""
        if not self.header_written:
            # ensure we get a ".TH" as viewers require it.
            self.head.append(self.header())
        # filter body: replace superfluous vertical gaps (``.sp``) that
        # follow certain requests with the no-op request ``.``.
        for i in xrange(len(self.body)-1, 0, -1):
            # remove superfluous vertical gaps.
            if self.body[i] == '.sp\n':
                # gap directly after an indented-paragraph request
                if self.body[i - 1][:4] in ('.BI ','.IP '):
                    self.body[i] = '.\n'
                # gap after a bold term introduced by ``.TP``
                elif (self.body[i - 1][:3] == '.B ' and
                    self.body[i - 2][:4] == '.TP\n'):
                    self.body[i] = '.\n'
                # gap after a plain line that follows a bold term
                # NOTE(review): for small i the i-3 lookback wraps to the
                # end of the list (negative index) — verify this is benign.
                elif (self.body[i - 1] == '\n' and
                    self.body[i - 2][0] != '.' and
                    (self.body[i - 3][:7] == '.TP\n.B '
                        or self.body[i - 3][:4] == '\n.B ')
                    ):
                    self.body[i] = '.\n'
        return ''.join(self.head + self.body + self.foot)
273
273
274 def deunicode(self, text):
274 def deunicode(self, text):
275 text = text.replace(u'\xa0', '\\ ')
275 text = text.replace(u'\xa0', '\\ ')
276 text = text.replace(u'\u2020', '\\(dg')
276 text = text.replace(u'\u2020', '\\(dg')
277 return text
277 return text
278
278
    def visit_Text(self, node):
        """Append the node's text to body, escaped for troff."""
        text = node.astext()
        # escape backslashes first so later escape sequences survive
        text = text.replace('\\','\\e')
        # NOTE(review): the third pair's source literal looks mojibake'd;
        # it is presumably the acute accent u'\u00b4' — verify encoding.
        replace_pairs = [
            (u'-', ur'\-'),
            (u'\'', ur'\(aq'),
            (u'Β΄', ur'\''),
            (u'`', ur'\(ga'),
            ]
        for (in_char, out_markup) in replace_pairs:
            text = text.replace(in_char, out_markup)
        # unicode
        text = self.deunicode(text)
        if self._in_literal:
            # prevent interpretation of "." at line start
            if text[0] == '.':
                text = '\\&' + text
            text = text.replace('\n.', '\n\\&.')
        self.body.append(text)

    def depart_Text(self, node):
        pass
301
301
    def list_start(self, node):
        """Open a (possibly nested) list: push its marker state and indent."""
        class enum_char(object):
            """Generate successive item markers for one list level.

            Closes over ``node`` to read the list's start value and size.
            """
            enum_style = {
                    'bullet'     : '\\(bu',
                    'emdash'     : '\\(em',
                     }

            def __init__(self, style):
                self._style = style
                if 'start' in node:
                    self._cnt = node['start'] - 1
                else:
                    self._cnt = 0
                self._indent = 2
                if style == 'arabic':
                    # indentation depends on the number of children
                    # and start value.
                    self._indent = len(str(len(node.children)))
                    self._indent += len(str(self._cnt)) + 1
                elif style == 'loweralpha':
                    self._cnt += ord('a') - 1
                    self._indent = 3
                elif style == 'upperalpha':
                    self._cnt += ord('A') - 1
                    self._indent = 3
                elif style.endswith('roman'):
                    self._indent = 5

            def next(self):
                """Return the marker text for the next list item."""
                if self._style == 'bullet':
                    return self.enum_style[self._style]
                elif self._style == 'emdash':
                    return self.enum_style[self._style]
                self._cnt += 1
                # TODO add prefix postfix
                if self._style == 'arabic':
                    return "%d." % self._cnt
                elif self._style in ('loweralpha', 'upperalpha'):
                    # counter was pre-offset in __init__ so %c yields a letter
                    return "%c." % self._cnt
                elif self._style.endswith('roman'):
                    res = roman.toRoman(self._cnt) + '.'
                    if self._style.startswith('upper'):
                        return res.upper()
                    return res.lower()
                else:
                    return "%d." % self._cnt

            def get_width(self):
                """Return the character width reserved for this level's markers."""
                return self._indent

            def __repr__(self):
                return 'enum_style-%s' % list(self._style)

        if 'enumtype' in node:
            self._list_char.append(enum_char(node['enumtype']))
        else:
            # bullet_list nodes carry no 'enumtype'
            self._list_char.append(enum_char('bullet'))
        if len(self._list_char) > 1:
            # indent nested lists
            self.indent(self._list_char[-2].get_width())
        else:
            self.indent(self._list_char[-1].get_width())
362
362
363 def list_end(self):
363 def list_end(self):
364 self.dedent()
364 self.dedent()
365 self._list_char.pop()
365 self._list_char.pop()
366
366
367 def header(self):
367 def header(self):
368 tmpl = (".TH %(title_upper)s %(manual_section)s"
368 tmpl = (".TH %(title_upper)s %(manual_section)s"
369 " \"%(date)s\" \"%(version)s\" \"%(manual_group)s\"\n"
369 " \"%(date)s\" \"%(version)s\" \"%(manual_group)s\"\n"
370 ".SH NAME\n"
370 ".SH NAME\n"
371 "%(title)s \- %(subtitle)s\n")
371 "%(title)s \- %(subtitle)s\n")
372 return tmpl % self._docinfo
372 return tmpl % self._docinfo
373
373
    def append_header(self):
        """append header with .TH and .SH NAME"""
        # NOTE before everything
        # .TH title_upper section date source manual
        if self.header_written:
            return
        self.body.append(self.header())
        self.body.append(MACRO_DEF)
        self.header_written = 1
383
383
    def visit_address(self, node):
        """Record the address field; it is emitted in ``depart_document``."""
        self.visit_docinfo_item(node, 'address')

    def depart_address(self, node):
        pass

    def visit_admonition(self, node, name=None):
        """Start an admonition as an indented paragraph labelled ``name``."""
        if name:
            # use the localized label when the language provides one
            self.body.append('.IP %s\n' %
                        self.language.labels.get(name, name))

    def depart_admonition(self, node):
        self.body.append('.RE\n')

    def visit_attention(self, node):
        self.visit_admonition(node, 'attention')

    depart_attention = depart_admonition

    def visit_docinfo_item(self, node, name):
        """Store a docinfo field for later emission and skip its children."""
        if name == 'author':
            # multiple authors accumulate into a list
            self._docinfo[name].append(node.astext())
        else:
            self._docinfo[name] = node.astext()
        self._docinfo_keys.append(name)
        # text already captured; do not visit children
        raise nodes.SkipNode

    def depart_docinfo_item(self, node):
        pass

    def visit_author(self, node):
        self.visit_docinfo_item(node, 'author')

    depart_author = depart_docinfo_item

    def visit_authors(self, node):
        # _author is called anyway.
        pass

    def depart_authors(self, node):
        pass
425
425
    def visit_block_quote(self, node):
        # BUG/HACK: indent always uses the _last_ indentation,
        # thus we need two of them.
        self.indent(BLOCKQOUTE_INDENT)
        self.indent(0)

    def depart_block_quote(self, node):
        # undo both indents pushed by visit_block_quote
        self.dedent()
        self.dedent()

    def visit_bullet_list(self, node):
        self.list_start(node)

    def depart_bullet_list(self, node):
        self.list_end()

    def visit_caption(self, node):
        pass

    def depart_caption(self, node):
        pass

    def visit_caution(self, node):
        self.visit_admonition(node, 'caution')

    depart_caution = depart_admonition

    def visit_citation(self, node):
        """Start a citation as an indented paragraph tagged ``[label]``."""
        # first whitespace-separated token is the citation label
        num, text = node.astext().split(None, 1)
        num = num.strip()
        self.body.append('.IP [%s] 5\n' % num)

    def depart_citation(self, node):
        pass

    def visit_citation_reference(self, node):
        """Emit an inline citation reference such as ``[1]``."""
        self.body.append('['+node.astext()+']')
        # text already emitted; do not descend into children
        raise nodes.SkipNode

    def visit_classifier(self, node):
        pass

    def depart_classifier(self, node):
        pass
470
470
    def visit_colspec(self, node):
        # collected column specs are consumed by write_colspecs
        self.colspecs.append(node)

    def depart_colspec(self, node):
        pass

    def write_colspecs(self):
        """Emit one left-aligned tbl column definition per collected colspec."""
        self.body.append("%s.\n" % ('L '*len(self.colspecs)))

    def visit_comment(self, node,
                      sub=re.compile('-(?=-)').sub):
        # NOTE(review): the ``sub`` default argument appears unused in the
        # body — verify whether it is vestigial.
        self.body.append(self.comment(node.astext()))
        raise nodes.SkipNode

    def visit_contact(self, node):
        self.visit_docinfo_item(node, 'contact')

    depart_contact = depart_docinfo_item

    def visit_container(self, node):
        pass

    def depart_container(self, node):
        pass

    def visit_compound(self, node):
        pass

    def depart_compound(self, node):
        pass

    def visit_copyright(self, node):
        self.visit_docinfo_item(node, 'copyright')

    def visit_danger(self, node):
        self.visit_admonition(node, 'danger')

    depart_danger = depart_admonition

    def visit_date(self, node):
        self.visit_docinfo_item(node, 'date')

    def visit_decoration(self, node):
        pass

    def depart_decoration(self, node):
        pass
518
518
    def visit_definition(self, node):
        pass

    def depart_definition(self, node):
        pass

    def visit_definition_list(self, node):
        self.indent(DEFINITION_LIST_INDENT)

    def depart_definition_list(self, node):
        self.dedent()

    def visit_definition_list_item(self, node):
        # '.TP' (tagged paragraph) opens each term/definition pair
        self.body.append(self.defs['definition_list_item'][0])

    def depart_definition_list_item(self, node):
        self.body.append(self.defs['definition_list_item'][1])

    def visit_description(self, node):
        pass

    def depart_description(self, node):
        pass

    def visit_docinfo(self, node):
        # flag read by field handlers while inside the docinfo block
        self._in_docinfo = 1

    def depart_docinfo(self, node):
        self._in_docinfo = None
        # NOTE nothing should be written before this
        self.append_header()

    def visit_doctest_block(self, node):
        # render doctests like literal blocks (no-fill, courier font)
        self.body.append(self.defs['literal_block'][0])
        self._in_literal = True

    def depart_doctest_block(self, node):
        self._in_literal = False
        self.body.append(self.defs['literal_block'][1])
558
558
    def visit_document(self, node):
        """Start the document: emit the generator comment only."""
        # no blank line between comment and header.
        self.body.append(self.comment(self.document_start).rstrip()+'\n')
        # writing the header is postponed until docinfo has been seen
        self.header_written = 0
564
564
    def depart_document(self, node):
        """Finish the document: AUTHOR, leftover docinfo fields, COPYRIGHT."""
        if self._docinfo['author']:
            self.body.append('.SH AUTHOR\n%s\n'
                    % ', '.join(self._docinfo['author']))
        # fields already rendered elsewhere (header or dedicated sections)
        skip = ('author', 'copyright', 'date',
               'manual_group', 'manual_section',
               'subtitle',
               'title', 'title_upper', 'version')
        for name in self._docinfo_keys:
            if name == 'address':
                # addresses keep their line breaks: wrap in no-fill mode
                self.body.append("\n%s:\n%s%s.nf\n%s\n.fi\n%s%s" % (
                                    self.language.labels.get(name, name),
                                    self.defs['indent'][0] % 0,
                                    self.defs['indent'][0] % BLOCKQOUTE_INDENT,
                                    self._docinfo[name],
                                    self.defs['indent'][1],
                                    self.defs['indent'][1]))
            elif not name in skip:
                # prefer the author's original (non-normalized) field name
                if name in self._docinfo_names:
                    label = self._docinfo_names[name]
                else:
                    label = self.language.labels.get(name, name)
                self.body.append("\n%s: %s\n" % (label, self._docinfo[name]))
        if self._docinfo['copyright']:
            self.body.append('.SH COPYRIGHT\n%s\n'
                    % self._docinfo['copyright'])
        self.body.append(self.comment(
                        'Generated by docutils manpage writer.\n'))
593
593
594 def visit_emphasis(self, node):
594 def visit_emphasis(self, node):
595 self.body.append(self.defs['emphasis'][0])
595 self.body.append(self.defs['emphasis'][0])
596
596
597 def depart_emphasis(self, node):
597 def depart_emphasis(self, node):
598 self.body.append(self.defs['emphasis'][1])
598 self.body.append(self.defs['emphasis'][1])
599
599
600 def visit_entry(self, node):
600 def visit_entry(self, node):
601 # a cell in a table row
601 # a cell in a table row
602 if 'morerows' in node:
602 if 'morerows' in node:
603 self.document.reporter.warning('"table row spanning" not supported',
603 self.document.reporter.warning('"table row spanning" not supported',
604 base_node=node)
604 base_node=node)
605 if 'morecols' in node:
605 if 'morecols' in node:
606 self.document.reporter.warning(
606 self.document.reporter.warning(
607 '"table cell spanning" not supported', base_node=node)
607 '"table cell spanning" not supported', base_node=node)
608 self.context.append(len(self.body))
608 self.context.append(len(self.body))
609
609
610 def depart_entry(self, node):
610 def depart_entry(self, node):
611 start = self.context.pop()
611 start = self.context.pop()
612 self._active_table.append_cell(self.body[start:])
612 self._active_table.append_cell(self.body[start:])
613 del self.body[start:]
613 del self.body[start:]
614
614
615 def visit_enumerated_list(self, node):
615 def visit_enumerated_list(self, node):
616 self.list_start(node)
616 self.list_start(node)
617
617
618 def depart_enumerated_list(self, node):
618 def depart_enumerated_list(self, node):
619 self.list_end()
619 self.list_end()
620
620
621 def visit_error(self, node):
621 def visit_error(self, node):
622 self.visit_admonition(node, 'error')
622 self.visit_admonition(node, 'error')
623
623
624 depart_error = depart_admonition
624 depart_error = depart_admonition
625
625
626 def visit_field(self, node):
626 def visit_field(self, node):
627 pass
627 pass
628
628
629 def depart_field(self, node):
629 def depart_field(self, node):
630 pass
630 pass
631
631
632 def visit_field_body(self, node):
632 def visit_field_body(self, node):
633 if self._in_docinfo:
633 if self._in_docinfo:
634 name_normalized = self._field_name.lower().replace(" ","_")
634 name_normalized = self._field_name.lower().replace(" ","_")
635 self._docinfo_names[name_normalized] = self._field_name
635 self._docinfo_names[name_normalized] = self._field_name
636 self.visit_docinfo_item(node, name_normalized)
636 self.visit_docinfo_item(node, name_normalized)
637 raise nodes.SkipNode
637 raise nodes.SkipNode
638
638
639 def depart_field_body(self, node):
639 def depart_field_body(self, node):
640 pass
640 pass
641
641
642 def visit_field_list(self, node):
642 def visit_field_list(self, node):
643 self.indent(FIELD_LIST_INDENT)
643 self.indent(FIELD_LIST_INDENT)
644
644
645 def depart_field_list(self, node):
645 def depart_field_list(self, node):
646 self.dedent()
646 self.dedent()
647
647
648 def visit_field_name(self, node):
648 def visit_field_name(self, node):
649 if self._in_docinfo:
649 if self._in_docinfo:
650 self._field_name = node.astext()
650 self._field_name = node.astext()
651 raise nodes.SkipNode
651 raise nodes.SkipNode
652 else:
652 else:
653 self.body.append(self.defs['field_name'][0])
653 self.body.append(self.defs['field_name'][0])
654
654
655 def depart_field_name(self, node):
655 def depart_field_name(self, node):
656 self.body.append(self.defs['field_name'][1])
656 self.body.append(self.defs['field_name'][1])
657
657
658 def visit_figure(self, node):
658 def visit_figure(self, node):
659 self.indent(2.5)
659 self.indent(2.5)
660 self.indent(0)
660 self.indent(0)
661
661
662 def depart_figure(self, node):
662 def depart_figure(self, node):
663 self.dedent()
663 self.dedent()
664 self.dedent()
664 self.dedent()
665
665
666 def visit_footer(self, node):
666 def visit_footer(self, node):
667 self.document.reporter.warning('"footer" not supported',
667 self.document.reporter.warning('"footer" not supported',
668 base_node=node)
668 base_node=node)
669
669
670 def depart_footer(self, node):
670 def depart_footer(self, node):
671 pass
671 pass
672
672
673 def visit_footnote(self, node):
673 def visit_footnote(self, node):
674 num, text = node.astext().split(None, 1)
674 num, text = node.astext().split(None, 1)
675 num = num.strip()
675 num = num.strip()
676 self.body.append('.IP [%s] 5\n' % self.deunicode(num))
676 self.body.append('.IP [%s] 5\n' % self.deunicode(num))
677
677
678 def depart_footnote(self, node):
678 def depart_footnote(self, node):
679 pass
679 pass
680
680
681 def footnote_backrefs(self, node):
681 def footnote_backrefs(self, node):
682 self.document.reporter.warning('"footnote_backrefs" not supported',
682 self.document.reporter.warning('"footnote_backrefs" not supported',
683 base_node=node)
683 base_node=node)
684
684
685 def visit_footnote_reference(self, node):
685 def visit_footnote_reference(self, node):
686 self.body.append('['+self.deunicode(node.astext())+']')
686 self.body.append('['+self.deunicode(node.astext())+']')
687 raise nodes.SkipNode
687 raise nodes.SkipNode
688
688
689 def depart_footnote_reference(self, node):
689 def depart_footnote_reference(self, node):
690 pass
690 pass
691
691
692 def visit_generated(self, node):
692 def visit_generated(self, node):
693 pass
693 pass
694
694
695 def depart_generated(self, node):
695 def depart_generated(self, node):
696 pass
696 pass
697
697
698 def visit_header(self, node):
698 def visit_header(self, node):
699 raise NotImplementedError, node.astext()
699 raise NotImplementedError, node.astext()
700
700
701 def depart_header(self, node):
701 def depart_header(self, node):
702 pass
702 pass
703
703
704 def visit_hint(self, node):
704 def visit_hint(self, node):
705 self.visit_admonition(node, 'hint')
705 self.visit_admonition(node, 'hint')
706
706
707 depart_hint = depart_admonition
707 depart_hint = depart_admonition
708
708
709 def visit_subscript(self, node):
709 def visit_subscript(self, node):
710 self.body.append('\\s-2\\d')
710 self.body.append('\\s-2\\d')
711
711
712 def depart_subscript(self, node):
712 def depart_subscript(self, node):
713 self.body.append('\\u\\s0')
713 self.body.append('\\u\\s0')
714
714
715 def visit_superscript(self, node):
715 def visit_superscript(self, node):
716 self.body.append('\\s-2\\u')
716 self.body.append('\\s-2\\u')
717
717
718 def depart_superscript(self, node):
718 def depart_superscript(self, node):
719 self.body.append('\\d\\s0')
719 self.body.append('\\d\\s0')
720
720
721 def visit_attribution(self, node):
721 def visit_attribution(self, node):
722 self.body.append('\\(em ')
722 self.body.append('\\(em ')
723
723
724 def depart_attribution(self, node):
724 def depart_attribution(self, node):
725 self.body.append('\n')
725 self.body.append('\n')
726
726
727 def visit_image(self, node):
727 def visit_image(self, node):
728 self.document.reporter.warning('"image" not supported',
728 self.document.reporter.warning('"image" not supported',
729 base_node=node)
729 base_node=node)
730 text = []
730 text = []
731 if 'alt' in node.attributes:
731 if 'alt' in node.attributes:
732 text.append(node.attributes['alt'])
732 text.append(node.attributes['alt'])
733 if 'uri' in node.attributes:
733 if 'uri' in node.attributes:
734 text.append(node.attributes['uri'])
734 text.append(node.attributes['uri'])
735 self.body.append('[image: %s]\n' % ('/'.join(text)))
735 self.body.append('[image: %s]\n' % ('/'.join(text)))
736 raise nodes.SkipNode
736 raise nodes.SkipNode
737
737
738 def visit_important(self, node):
738 def visit_important(self, node):
739 self.visit_admonition(node, 'important')
739 self.visit_admonition(node, 'important')
740
740
741 depart_important = depart_admonition
741 depart_important = depart_admonition
742
742
743 def visit_label(self, node):
743 def visit_label(self, node):
744 # footnote and citation
744 # footnote and citation
745 if (isinstance(node.parent, nodes.footnote)
745 if (isinstance(node.parent, nodes.footnote)
746 or isinstance(node.parent, nodes.citation)):
746 or isinstance(node.parent, nodes.citation)):
747 raise nodes.SkipNode
747 raise nodes.SkipNode
748 self.document.reporter.warning('"unsupported "label"',
748 self.document.reporter.warning('"unsupported "label"',
749 base_node=node)
749 base_node=node)
750 self.body.append('[')
750 self.body.append('[')
751
751
752 def depart_label(self, node):
752 def depart_label(self, node):
753 self.body.append(']\n')
753 self.body.append(']\n')
754
754
755 def visit_legend(self, node):
755 def visit_legend(self, node):
756 pass
756 pass
757
757
758 def depart_legend(self, node):
758 def depart_legend(self, node):
759 pass
759 pass
760
760
761 # WHAT should we use .INDENT, .UNINDENT ?
761 # WHAT should we use .INDENT, .UNINDENT ?
762 def visit_line_block(self, node):
762 def visit_line_block(self, node):
763 self._line_block += 1
763 self._line_block += 1
764 if self._line_block == 1:
764 if self._line_block == 1:
765 self.body.append('.sp\n')
765 self.body.append('.sp\n')
766 self.body.append('.nf\n')
766 self.body.append('.nf\n')
767 else:
767 else:
768 self.body.append('.in +2\n')
768 self.body.append('.in +2\n')
769
769
770 def depart_line_block(self, node):
770 def depart_line_block(self, node):
771 self._line_block -= 1
771 self._line_block -= 1
772 if self._line_block == 0:
772 if self._line_block == 0:
773 self.body.append('.fi\n')
773 self.body.append('.fi\n')
774 self.body.append('.sp\n')
774 self.body.append('.sp\n')
775 else:
775 else:
776 self.body.append('.in -2\n')
776 self.body.append('.in -2\n')
777
777
778 def visit_line(self, node):
778 def visit_line(self, node):
779 pass
779 pass
780
780
781 def depart_line(self, node):
781 def depart_line(self, node):
782 self.body.append('\n')
782 self.body.append('\n')
783
783
784 def visit_list_item(self, node):
784 def visit_list_item(self, node):
785 # man 7 man argues to use ".IP" instead of ".TP"
785 # man 7 man argues to use ".IP" instead of ".TP"
786 self.body.append('.IP %s %d\n' % (
786 self.body.append('.IP %s %d\n' % (
787 self._list_char[-1].next(),
787 self._list_char[-1].next(),
788 self._list_char[-1].get_width(),))
788 self._list_char[-1].get_width(),))
789
789
790 def depart_list_item(self, node):
790 def depart_list_item(self, node):
791 pass
791 pass
792
792
793 def visit_literal(self, node):
793 def visit_literal(self, node):
794 self.body.append(self.defs['literal'][0])
794 self.body.append(self.defs['literal'][0])
795
795
796 def depart_literal(self, node):
796 def depart_literal(self, node):
797 self.body.append(self.defs['literal'][1])
797 self.body.append(self.defs['literal'][1])
798
798
799 def visit_literal_block(self, node):
799 def visit_literal_block(self, node):
800 self.body.append(self.defs['literal_block'][0])
800 self.body.append(self.defs['literal_block'][0])
801 self._in_literal = True
801 self._in_literal = True
802
802
803 def depart_literal_block(self, node):
803 def depart_literal_block(self, node):
804 self._in_literal = False
804 self._in_literal = False
805 self.body.append(self.defs['literal_block'][1])
805 self.body.append(self.defs['literal_block'][1])
806
806
807 def visit_meta(self, node):
807 def visit_meta(self, node):
808 raise NotImplementedError, node.astext()
808 raise NotImplementedError, node.astext()
809
809
810 def depart_meta(self, node):
810 def depart_meta(self, node):
811 pass
811 pass
812
812
813 def visit_note(self, node):
813 def visit_note(self, node):
814 self.visit_admonition(node, 'note')
814 self.visit_admonition(node, 'note')
815
815
816 depart_note = depart_admonition
816 depart_note = depart_admonition
817
817
818 def indent(self, by=0.5):
818 def indent(self, by=0.5):
819 # if we are in a section ".SH" there already is a .RS
819 # if we are in a section ".SH" there already is a .RS
820 step = self._indent[-1]
820 step = self._indent[-1]
821 self._indent.append(by)
821 self._indent.append(by)
822 self.body.append(self.defs['indent'][0] % step)
822 self.body.append(self.defs['indent'][0] % step)
823
823
824 def dedent(self):
824 def dedent(self):
825 self._indent.pop()
825 self._indent.pop()
826 self.body.append(self.defs['indent'][1])
826 self.body.append(self.defs['indent'][1])
827
827
828 def visit_option_list(self, node):
828 def visit_option_list(self, node):
829 self.indent(OPTION_LIST_INDENT)
829 self.indent(OPTION_LIST_INDENT)
830
830
831 def depart_option_list(self, node):
831 def depart_option_list(self, node):
832 self.dedent()
832 self.dedent()
833
833
834 def visit_option_list_item(self, node):
834 def visit_option_list_item(self, node):
835 # one item of the list
835 # one item of the list
836 self.body.append(self.defs['option_list_item'][0])
836 self.body.append(self.defs['option_list_item'][0])
837
837
838 def depart_option_list_item(self, node):
838 def depart_option_list_item(self, node):
839 self.body.append(self.defs['option_list_item'][1])
839 self.body.append(self.defs['option_list_item'][1])
840
840
841 def visit_option_group(self, node):
841 def visit_option_group(self, node):
842 # as one option could have several forms it is a group
842 # as one option could have several forms it is a group
843 # options without parameter bold only, .B, -v
843 # options without parameter bold only, .B, -v
844 # options with parameter bold italic, .BI, -f file
844 # options with parameter bold italic, .BI, -f file
845 #
845 #
846 # we do not know if .B or .BI
846 # we do not know if .B or .BI
847 self.context.append('.B') # blind guess
847 self.context.append('.B') # blind guess
848 self.context.append(len(self.body)) # to be able to insert later
848 self.context.append(len(self.body)) # to be able to insert later
849 self.context.append(0) # option counter
849 self.context.append(0) # option counter
850
850
851 def depart_option_group(self, node):
851 def depart_option_group(self, node):
852 self.context.pop() # the counter
852 self.context.pop() # the counter
853 start_position = self.context.pop()
853 start_position = self.context.pop()
854 text = self.body[start_position:]
854 text = self.body[start_position:]
855 del self.body[start_position:]
855 del self.body[start_position:]
856 self.body.append('%s%s\n' % (self.context.pop(), ''.join(text)))
856 self.body.append('%s%s\n' % (self.context.pop(), ''.join(text)))
857
857
858 def visit_option(self, node):
858 def visit_option(self, node):
859 # each form of the option will be presented separately
859 # each form of the option will be presented separately
860 if self.context[-1] > 0:
860 if self.context[-1] > 0:
861 self.body.append(', ')
861 self.body.append(', ')
862 if self.context[-3] == '.BI':
862 if self.context[-3] == '.BI':
863 self.body.append('\\')
863 self.body.append('\\')
864 self.body.append(' ')
864 self.body.append(' ')
865
865
866 def depart_option(self, node):
866 def depart_option(self, node):
867 self.context[-1] += 1
867 self.context[-1] += 1
868
868
869 def visit_option_string(self, node):
869 def visit_option_string(self, node):
870 # do not know if .B or .BI
870 # do not know if .B or .BI
871 pass
871 pass
872
872
873 def depart_option_string(self, node):
873 def depart_option_string(self, node):
874 pass
874 pass
875
875
876 def visit_option_argument(self, node):
876 def visit_option_argument(self, node):
877 self.context[-3] = '.BI' # bold/italic alternate
877 self.context[-3] = '.BI' # bold/italic alternate
878 if node['delimiter'] != ' ':
878 if node['delimiter'] != ' ':
879 self.body.append('\\fB%s ' % node['delimiter'])
879 self.body.append('\\fB%s ' % node['delimiter'])
880 elif self.body[len(self.body)-1].endswith('='):
880 elif self.body[len(self.body)-1].endswith('='):
881 # a blank only means no blank in output, just changing font
881 # a blank only means no blank in output, just changing font
882 self.body.append(' ')
882 self.body.append(' ')
883 else:
883 else:
884 # blank backslash blank, switch font then a blank
884 # blank backslash blank, switch font then a blank
885 self.body.append(' \\ ')
885 self.body.append(' \\ ')
886
886
887 def depart_option_argument(self, node):
887 def depart_option_argument(self, node):
888 pass
888 pass
889
889
890 def visit_organization(self, node):
890 def visit_organization(self, node):
891 self.visit_docinfo_item(node, 'organization')
891 self.visit_docinfo_item(node, 'organization')
892
892
893 def depart_organization(self, node):
893 def depart_organization(self, node):
894 pass
894 pass
895
895
896 def visit_paragraph(self, node):
896 def visit_paragraph(self, node):
897 # ``.PP`` : Start standard indented paragraph.
897 # ``.PP`` : Start standard indented paragraph.
898 # ``.LP`` : Start block paragraph, all except the first.
898 # ``.LP`` : Start block paragraph, all except the first.
899 # ``.P [type]`` : Start paragraph type.
899 # ``.P [type]`` : Start paragraph type.
900 # NOTE dont use paragraph starts because they reset indentation.
900 # NOTE dont use paragraph starts because they reset indentation.
901 # ``.sp`` is only vertical space
901 # ``.sp`` is only vertical space
902 self.ensure_eol()
902 self.ensure_eol()
903 self.body.append('.sp\n')
903 self.body.append('.sp\n')
904
904
905 def depart_paragraph(self, node):
905 def depart_paragraph(self, node):
906 self.body.append('\n')
906 self.body.append('\n')
907
907
908 def visit_problematic(self, node):
908 def visit_problematic(self, node):
909 self.body.append(self.defs['problematic'][0])
909 self.body.append(self.defs['problematic'][0])
910
910
911 def depart_problematic(self, node):
911 def depart_problematic(self, node):
912 self.body.append(self.defs['problematic'][1])
912 self.body.append(self.defs['problematic'][1])
913
913
914 def visit_raw(self, node):
914 def visit_raw(self, node):
915 if node.get('format') == 'manpage':
915 if node.get('format') == 'manpage':
916 self.body.append(node.astext() + "\n")
916 self.body.append(node.astext() + "\n")
917 # Keep non-manpage raw text out of output:
917 # Keep non-manpage raw text out of output:
918 raise nodes.SkipNode
918 raise nodes.SkipNode
919
919
920 def visit_reference(self, node):
920 def visit_reference(self, node):
921 """E.g. link or email address."""
921 """E.g. link or email address."""
922 self.body.append(self.defs['reference'][0])
922 self.body.append(self.defs['reference'][0])
923
923
924 def depart_reference(self, node):
924 def depart_reference(self, node):
925 self.body.append(self.defs['reference'][1])
925 self.body.append(self.defs['reference'][1])
926
926
927 def visit_revision(self, node):
927 def visit_revision(self, node):
928 self.visit_docinfo_item(node, 'revision')
928 self.visit_docinfo_item(node, 'revision')
929
929
930 depart_revision = depart_docinfo_item
930 depart_revision = depart_docinfo_item
931
931
932 def visit_row(self, node):
932 def visit_row(self, node):
933 self._active_table.new_row()
933 self._active_table.new_row()
934
934
935 def depart_row(self, node):
935 def depart_row(self, node):
936 pass
936 pass
937
937
938 def visit_section(self, node):
938 def visit_section(self, node):
939 self.section_level += 1
939 self.section_level += 1
940
940
941 def depart_section(self, node):
941 def depart_section(self, node):
942 self.section_level -= 1
942 self.section_level -= 1
943
943
944 def visit_status(self, node):
944 def visit_status(self, node):
945 self.visit_docinfo_item(node, 'status')
945 self.visit_docinfo_item(node, 'status')
946
946
947 depart_status = depart_docinfo_item
947 depart_status = depart_docinfo_item
948
948
949 def visit_strong(self, node):
949 def visit_strong(self, node):
950 self.body.append(self.defs['strong'][0])
950 self.body.append(self.defs['strong'][0])
951
951
952 def depart_strong(self, node):
952 def depart_strong(self, node):
953 self.body.append(self.defs['strong'][1])
953 self.body.append(self.defs['strong'][1])
954
954
955 def visit_substitution_definition(self, node):
955 def visit_substitution_definition(self, node):
956 """Internal only."""
956 """Internal only."""
957 raise nodes.SkipNode
957 raise nodes.SkipNode
958
958
959 def visit_substitution_reference(self, node):
959 def visit_substitution_reference(self, node):
960 self.document.reporter.warning('"substitution_reference" not supported',
960 self.document.reporter.warning('"substitution_reference" not supported',
961 base_node=node)
961 base_node=node)
962
962
963 def visit_subtitle(self, node):
963 def visit_subtitle(self, node):
964 if isinstance(node.parent, nodes.sidebar):
964 if isinstance(node.parent, nodes.sidebar):
965 self.body.append(self.defs['strong'][0])
965 self.body.append(self.defs['strong'][0])
966 elif isinstance(node.parent, nodes.document):
966 elif isinstance(node.parent, nodes.document):
967 self.visit_docinfo_item(node, 'subtitle')
967 self.visit_docinfo_item(node, 'subtitle')
968 elif isinstance(node.parent, nodes.section):
968 elif isinstance(node.parent, nodes.section):
969 self.body.append(self.defs['strong'][0])
969 self.body.append(self.defs['strong'][0])
970
970
971 def depart_subtitle(self, node):
971 def depart_subtitle(self, node):
972 # document subtitle calls SkipNode
972 # document subtitle calls SkipNode
973 self.body.append(self.defs['strong'][1]+'\n.PP\n')
973 self.body.append(self.defs['strong'][1]+'\n.PP\n')
974
974
975 def visit_system_message(self, node):
975 def visit_system_message(self, node):
976 # TODO add report_level
976 # TODO add report_level
977 #if node['level'] < self.document.reporter['writer'].report_level:
977 #if node['level'] < self.document.reporter['writer'].report_level:
978 # Level is too low to display:
978 # Level is too low to display:
979 # raise nodes.SkipNode
979 # raise nodes.SkipNode
980 attr = {}
980 attr = {}
981 backref_text = ''
981 backref_text = ''
982 if node.hasattr('id'):
982 if node.hasattr('id'):
983 attr['name'] = node['id']
983 attr['name'] = node['id']
984 if node.hasattr('line'):
984 if node.hasattr('line'):
985 line = ', line %s' % node['line']
985 line = ', line %s' % node['line']
986 else:
986 else:
987 line = ''
987 line = ''
988 self.body.append('.IP "System Message: %s/%s (%s:%s)"\n'
988 self.body.append('.IP "System Message: %s/%s (%s:%s)"\n'
989 % (node['type'], node['level'], node['source'], line))
989 % (node['type'], node['level'], node['source'], line))
990
990
991 def depart_system_message(self, node):
991 def depart_system_message(self, node):
992 pass
992 pass
993
993
994 def visit_table(self, node):
994 def visit_table(self, node):
995 self._active_table = Table()
995 self._active_table = Table()
996
996
997 def depart_table(self, node):
997 def depart_table(self, node):
998 self.ensure_eol()
998 self.ensure_eol()
999 self.body.extend(self._active_table.as_list())
999 self.body.extend(self._active_table.as_list())
1000 self._active_table = None
1000 self._active_table = None
1001
1001
1002 def visit_target(self, node):
1002 def visit_target(self, node):
1003 # targets are in-document hyper targets, without any use for man-pages.
1003 # targets are in-document hyper targets, without any use for man-pages.
1004 raise nodes.SkipNode
1004 raise nodes.SkipNode
1005
1005
1006 def visit_tbody(self, node):
1006 def visit_tbody(self, node):
1007 pass
1007 pass
1008
1008
1009 def depart_tbody(self, node):
1009 def depart_tbody(self, node):
1010 pass
1010 pass
1011
1011
1012 def visit_term(self, node):
1012 def visit_term(self, node):
1013 self.body.append(self.defs['term'][0])
1013 self.body.append(self.defs['term'][0])
1014
1014
1015 def depart_term(self, node):
1015 def depart_term(self, node):
1016 self.body.append(self.defs['term'][1])
1016 self.body.append(self.defs['term'][1])
1017
1017
1018 def visit_tgroup(self, node):
1018 def visit_tgroup(self, node):
1019 pass
1019 pass
1020
1020
1021 def depart_tgroup(self, node):
1021 def depart_tgroup(self, node):
1022 pass
1022 pass
1023
1023
1024 def visit_thead(self, node):
1024 def visit_thead(self, node):
1025 # MAYBE double line '='
1025 # MAYBE double line '='
1026 pass
1026 pass
1027
1027
1028 def depart_thead(self, node):
1028 def depart_thead(self, node):
1029 # MAYBE double line '='
1029 # MAYBE double line '='
1030 pass
1030 pass
1031
1031
1032 def visit_tip(self, node):
1032 def visit_tip(self, node):
1033 self.visit_admonition(node, 'tip')
1033 self.visit_admonition(node, 'tip')
1034
1034
1035 depart_tip = depart_admonition
1035 depart_tip = depart_admonition
1036
1036
1037 def visit_title(self, node):
1037 def visit_title(self, node):
1038 if isinstance(node.parent, nodes.topic):
1038 if isinstance(node.parent, nodes.topic):
1039 self.body.append(self.defs['topic-title'][0])
1039 self.body.append(self.defs['topic-title'][0])
1040 elif isinstance(node.parent, nodes.sidebar):
1040 elif isinstance(node.parent, nodes.sidebar):
1041 self.body.append(self.defs['sidebar-title'][0])
1041 self.body.append(self.defs['sidebar-title'][0])
1042 elif isinstance(node.parent, nodes.admonition):
1042 elif isinstance(node.parent, nodes.admonition):
1043 self.body.append('.IP "')
1043 self.body.append('.IP "')
1044 elif self.section_level == 0:
1044 elif self.section_level == 0:
1045 self._docinfo['title'] = node.astext()
1045 self._docinfo['title'] = node.astext()
1046 # document title for .TH
1046 # document title for .TH
1047 self._docinfo['title_upper'] = node.astext().upper()
1047 self._docinfo['title_upper'] = node.astext().upper()
1048 raise nodes.SkipNode
1048 raise nodes.SkipNode
1049 elif self.section_level == 1:
1049 elif self.section_level == 1:
1050 self.body.append('.SH ')
1050 self.body.append('.SH ')
1051 for n in node.traverse(nodes.Text):
1051 for n in node.traverse(nodes.Text):
1052 n.parent.replace(n, nodes.Text(n.astext().upper()))
1052 n.parent.replace(n, nodes.Text(n.astext().upper()))
1053 else:
1053 else:
1054 self.body.append('.SS ')
1054 self.body.append('.SS ')
1055
1055
1056 def depart_title(self, node):
1056 def depart_title(self, node):
1057 if isinstance(node.parent, nodes.admonition):
1057 if isinstance(node.parent, nodes.admonition):
1058 self.body.append('"')
1058 self.body.append('"')
1059 self.body.append('\n')
1059 self.body.append('\n')
1060
1060
1061 def visit_title_reference(self, node):
1061 def visit_title_reference(self, node):
1062 """inline citation reference"""
1062 """inline citation reference"""
1063 self.body.append(self.defs['title_reference'][0])
1063 self.body.append(self.defs['title_reference'][0])
1064
1064
1065 def depart_title_reference(self, node):
1065 def depart_title_reference(self, node):
1066 self.body.append(self.defs['title_reference'][1])
1066 self.body.append(self.defs['title_reference'][1])
1067
1067
1068 def visit_topic(self, node):
1068 def visit_topic(self, node):
1069 pass
1069 pass
1070
1070
1071 def depart_topic(self, node):
1071 def depart_topic(self, node):
1072 pass
1072 pass
1073
1073
1074 def visit_sidebar(self, node):
1074 def visit_sidebar(self, node):
1075 pass
1075 pass
1076
1076
1077 def depart_sidebar(self, node):
1077 def depart_sidebar(self, node):
1078 pass
1078 pass
1079
1079
1080 def visit_rubric(self, node):
1080 def visit_rubric(self, node):
1081 pass
1081 pass
1082
1082
1083 def depart_rubric(self, node):
1083 def depart_rubric(self, node):
1084 pass
1084 pass
1085
1085
1086 def visit_transition(self, node):
1086 def visit_transition(self, node):
1087 # .PP Begin a new paragraph and reset prevailing indent.
1087 # .PP Begin a new paragraph and reset prevailing indent.
1088 # .sp N leaves N lines of blank space.
1088 # .sp N leaves N lines of blank space.
1089 # .ce centers the next line
1089 # .ce centers the next line
1090 self.body.append('\n.sp\n.ce\n----\n')
1090 self.body.append('\n.sp\n.ce\n----\n')
1091
1091
1092 def depart_transition(self, node):
1092 def depart_transition(self, node):
1093 self.body.append('\n.ce 0\n.sp\n')
1093 self.body.append('\n.ce 0\n.sp\n')
1094
1094
1095 def visit_version(self, node):
1095 def visit_version(self, node):
1096 self.visit_docinfo_item(node, 'version')
1096 self.visit_docinfo_item(node, 'version')
1097
1097
1098 def visit_warning(self, node):
1098 def visit_warning(self, node):
1099 self.visit_admonition(node, 'warning')
1099 self.visit_admonition(node, 'warning')
1100
1100
1101 depart_warning = depart_admonition
1101 depart_warning = depart_admonition
1102
1102
1103 def unimplemented_visit(self, node):
1103 def unimplemented_visit(self, node):
1104 raise NotImplementedError('visiting unimplemented node type: %s'
1104 raise NotImplementedError('visiting unimplemented node type: %s'
1105 % node.__class__.__name__)
1105 % node.__class__.__name__)
1106
1106
1107 # vim: set fileencoding=utf-8 et ts=4 ai :
1107 # vim: set fileencoding=utf-8 et ts=4 ai :
@@ -1,466 +1,466 b''
1 # This library is free software; you can redistribute it and/or
1 # This library is free software; you can redistribute it and/or
2 # modify it under the terms of the GNU Lesser General Public
2 # modify it under the terms of the GNU Lesser General Public
3 # License as published by the Free Software Foundation; either
3 # License as published by the Free Software Foundation; either
4 # version 2.1 of the License, or (at your option) any later version.
4 # version 2.1 of the License, or (at your option) any later version.
5 #
5 #
6 # This library is distributed in the hope that it will be useful,
6 # This library is distributed in the hope that it will be useful,
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9 # Lesser General Public License for more details.
9 # Lesser General Public License for more details.
10 #
10 #
11 # You should have received a copy of the GNU Lesser General Public
11 # You should have received a copy of the GNU Lesser General Public
12 # License along with this library; if not, write to the
12 # License along with this library; if not, write to the
13 # Free Software Foundation, Inc.,
13 # Free Software Foundation, Inc.,
14 # 59 Temple Place, Suite 330,
14 # 59 Temple Place, Suite 330,
15 # Boston, MA 02111-1307 USA
15 # Boston, MA 02111-1307 USA
16
16
17 # This file is part of urlgrabber, a high-level cross-protocol url-grabber
17 # This file is part of urlgrabber, a high-level cross-protocol url-grabber
18 # Copyright 2002-2004 Michael D. Stenner, Ryan Tomayko
18 # Copyright 2002-2004 Michael D. Stenner, Ryan Tomayko
19
19
20 # $Id: byterange.py,v 1.9 2005/02/14 21:55:07 mstenner Exp $
20 # $Id: byterange.py,v 1.9 2005/02/14 21:55:07 mstenner Exp $
21
21
22 import os
22 import os
23 import stat
23 import stat
24 import urllib
24 import urllib
25 import urllib2
25 import urllib2
26 import email.Utils
26 import email.Utils
27
27
class RangeError(IOError):
    """Error raised when an unsatisfiable range is requested."""
31
31
class HTTPRangeHandler(urllib2.BaseHandler):
    """urllib2 handler that makes ranged HTTP requests usable.

    HTTP already defines the Range header, so the only client-side work
    is to treat the server's "206 Partial Content" status as success
    rather than an error, and to surface "416" as a RangeError.

    Example:
        import urllib2
        import byterange

        range_handler = range.HTTPRangeHandler()
        opener = urllib2.build_opener(range_handler)

        # install it
        urllib2.install_opener(opener)

        # create Request and set Range header
        req = urllib2.Request('http://www.python.org/')
        req.header['Range'] = 'bytes=30-50'
        f = urllib2.urlopen(req)
    """

    def http_error_206(self, req, fp, code, msg, hdrs):
        # 206 Partial Content is the expected success case for a ranged
        # request: wrap it like an ordinary response.
        resp = urllib.addinfourl(fp, hdrs, req.get_full_url())
        resp.code = code
        resp.msg = msg
        return resp

    def http_error_416(self, req, fp, code, msg, hdrs):
        # The server could not satisfy the requested byte range.
        raise RangeError('Requested Range Not Satisfiable')
66
66
class RangeableFileObject(object):
    """Wrap a file-like object so it exposes only a byte range.

    Written primarily so file:// URLs can honour Range requests: the
    wrapper makes the underlying stream look as if it contained nothing
    but the bytes between ``firstbyte`` and ``lastbyte``.  tell() is 0
    at the start of the range, read()/readline() never run past its
    end, and attributes this class does not define are delegated to the
    wrapped object.
    """

    def __init__(self, fo, rangetup):
        """Wrap file-like object *fo*, restricted to *rangetup*.

        fo -- any object with a read() method; a native seek() is used
              when available, otherwise seeking is emulated via read().
        rangetup -- a (firstbyte, lastbyte) tuple describing the range.
              *fo* is assumed to be positioned at byte offset 0.
        """
        self.fo = fo
        (self.firstbyte, self.lastbyte) = range_tuple_normalize(rangetup)
        self.realpos = 0
        self._do_seek(self.firstbyte)

    def __getattr__(self, name):
        """Delegate unknown attribute access to the wrapped object."""
        target = self.fo
        if not hasattr(target, name):
            raise AttributeError(name)
        return getattr(target, name)

    def tell(self):
        """Return the current position relative to the range start.

        E.g. for a (500, 899) range, tell() is 0 when the underlying
        file is at byte 500.
        """
        return self.realpos - self.firstbyte

    def seek(self, offset, whence=0):
        """Seek within the byte range (same coordinates as tell())."""
        assert whence in (0, 1, 2)
        if whence == 0:
            # absolute position within the range
            destination = self.firstbyte + offset
        elif whence == 1:
            # relative to the current position
            destination = self.realpos + offset
        else:
            # XXX: are we raising the right Error here?
            raise IOError('seek from end of file not supported.')

        # never allow a seek past the last byte of the range
        if self.lastbyte and destination >= self.lastbyte:
            destination = self.lastbyte

        self._do_seek(destination - self.realpos)

    def read(self, size=-1):
        """Read at most *size* bytes, never beyond the range end."""
        limited = self._calc_read_size(size)
        data = self.fo.read(limited)
        self.realpos += len(data)
        return data

    def readline(self, size=-1):
        """Read one line, never beyond the range end."""
        limited = self._calc_read_size(size)
        data = self.fo.readline(limited)
        self.realpos += len(data)
        return data

    def _calc_read_size(self, size):
        """Clip a requested read size so it stays inside the range."""
        if not self.lastbyte:
            # open-ended range: pass the request through unchanged
            return size
        if size <= -1 or (self.realpos + size) >= self.lastbyte:
            size = self.lastbyte - self.realpos
        return size

    def _do_seek(self, offset):
        """Advance *offset* bytes (>= 0) from the current real position."""
        assert offset >= 0
        if hasattr(self.fo, 'seek'):
            self.fo.seek(self.realpos + offset)
        else:
            # wrapped object cannot seek; consume bytes instead
            self._poor_mans_seek(offset)
        self.realpos += offset

    def _poor_mans_seek(self, offset):
        """Emulate a forward seek by reading and discarding bytes.

        Raises RangeError if EOF is reached before *offset* bytes have
        been consumed.
        """
        consumed = 0
        chunksize = 1024
        while consumed < offset:
            if (consumed + chunksize) > offset:
                chunksize = offset - consumed
            data = self.fo.read(chunksize)
            if len(data) != chunksize:
                raise RangeError('Requested Range Not Satisfiable')
            consumed += chunksize
198
198
class FileRangeHandler(urllib2.FileHandler):
    """FileHandler subclass that adds Range support.

    Range headers on file:// requests are honoured exactly the way an
    HTTP server would honour them.
    """
    def open_local_file(self, req):
        import mimetypes
        import email
        host = req.get_host()
        selector = req.get_selector()
        localfile = urllib.url2pathname(selector)
        stats = os.stat(localfile)
        size = stats[stat.ST_SIZE]
        modified = email.Utils.formatdate(stats[stat.ST_MTIME])
        mtype = mimetypes.guess_type(selector)[0]
        if host:
            # a host component is only valid if it names this machine
            host, port = urllib.splitport(host)
            if port or socket.gethostbyname(host) not in self.get_names():
                raise urllib2.URLError('file not on local host')
        fileobj = open(localfile, 'rb')
        byte_range = range_header_to_tuple(req.headers.get('Range', None))
        assert byte_range != ()
        if byte_range:
            first, last = byte_range
            if last == '':
                # open-ended range runs to the end of the file
                last = size
            if first < 0 or first > size or last > size:
                raise RangeError('Requested Range Not Satisfiable')
            size = last - first
            fileobj = RangeableFileObject(fileobj, (first, last))
        headers = email.message_from_string(
            'Content-Type: %s\nContent-Length: %d\nLast-Modified: %s\n' %
            (mtype or 'text/plain', size, modified))
        return urllib.addinfourl(fileobj, headers, 'file:' + selector)
234
234
235
235
236 # FTP Range Support
236 # FTP Range Support
237 # Unfortunately, a large amount of base FTP code had to be copied
237 # Unfortunately, a large amount of base FTP code had to be copied
238 # from urllib and urllib2 in order to insert the FTP REST command.
238 # from urllib and urllib2 in order to insert the FTP REST command.
239 # Code modifications for range support have been commented as
239 # Code modifications for range support have been commented as
240 # follows:
240 # follows:
241 # -- range support modifications start/end here
241 # -- range support modifications start/end here
242
242
243 from urllib import splitport, splituser, splitpasswd, splitattr, \
243 from urllib import splitport, splituser, splitpasswd, splitattr, \
244 unquote, addclosehook, addinfourl
244 unquote, addclosehook, addinfourl
245 import ftplib
245 import ftplib
246 import socket
246 import socket
247 import sys
247 import sys
248 import mimetypes
248 import mimetypes
249 import email
249 import email
250
250
class FTPRangeHandler(urllib2.FTPHandler):
    """FTPHandler subclass that adds Range support to ftp:// URLs.

    The bulk of ftp_open is copied from urllib2's FTP handling; the
    range-specific additions are marked with the
    '-- range support modifications' comments below.
    """
    def ftp_open(self, req):
        # Resolve host and port from the request URL.
        host = req.get_host()
        if not host:
            raise IOError('ftp error', 'no host given')
        host, port = splitport(host)
        if port is None:
            port = ftplib.FTP_PORT
        else:
            port = int(port)

        # username/password handling
        user, host = splituser(host)
        if user:
            user, passwd = splitpasswd(user)
        else:
            passwd = None
        host = unquote(host)
        user = unquote(user or '')
        passwd = unquote(passwd or '')

        try:
            host = socket.gethostbyname(host)
        except socket.error, msg:
            raise urllib2.URLError(msg)
        # Split the selector into directory components plus a file name.
        path, attrs = splitattr(req.get_selector())
        dirs = path.split('/')
        dirs = map(unquote, dirs)
        dirs, file = dirs[:-1], dirs[-1]
        if dirs and not dirs[0]:
            dirs = dirs[1:]
        try:
            fw = self.connect_ftp(user, passwd, host, port, dirs)
            # 'I' (binary) when a file is requested, 'D' (directory
            # listing) otherwise; a ';type=' URL attribute may override.
            type = file and 'I' or 'D'
            for attr in attrs:
                attr, value = splitattr(attr)
                if attr.lower() == 'type' and \
                   value in ('a', 'A', 'i', 'I', 'd', 'D'):
                    type = value.upper()

            # -- range support modifications start here
            # Translate a Range header into an FTP REST (restart) offset.
            rest = None
            range_tup = range_header_to_tuple(req.headers.get('Range', None))
            assert range_tup != ()
            if range_tup:
                (fb, lb) = range_tup
                if fb > 0:
                    rest = fb
            # -- range support modifications end here

            fp, retrlen = fw.retrfile(file, type, rest)

            # -- range support modifications start here
            # Adjust the reported length for the requested range and wrap
            # the data stream so reads stop at the range's last byte.
            if range_tup:
                (fb, lb) = range_tup
                if lb == '':
                    if retrlen is None or retrlen == 0:
                        raise RangeError('Requested Range Not Satisfiable due'
                                         ' to unobtainable file length.')
                    lb = retrlen
                    retrlen = lb - fb
                    if retrlen < 0:
                        # beginning of range is larger than file
                        raise RangeError('Requested Range Not Satisfiable')
                else:
                    retrlen = lb - fb
                fp = RangeableFileObject(fp, (0, retrlen))
            # -- range support modifications end here

            headers = ""
            mtype = mimetypes.guess_type(req.get_full_url())[0]
            if mtype:
                headers += "Content-Type: %s\n" % mtype
            if retrlen is not None and retrlen >= 0:
                headers += "Content-Length: %d\n" % retrlen
            headers = email.message_from_string(headers)
            return addinfourl(fp, headers, req.get_full_url())
        except ftplib.all_errors, msg:
            raise IOError('ftp error', msg), sys.exc_info()[2]

    def connect_ftp(self, user, passwd, host, port, dirs):
        # Use the range-aware ftpwrapper defined below instead of
        # urllib's default wrapper.
        fw = ftpwrapper(user, passwd, host, port, dirs)
        return fw
334
334
class ftpwrapper(urllib.ftpwrapper):
    # range support note:
    # this ftpwrapper code is copied directly from
    # urllib. The only enhancement is to add the rest
    # argument and pass it on to ftp.ntransfercmd
    def retrfile(self, file, type, rest=None):
        """Retrieve *file* over the open FTP connection.

        rest -- optional byte offset passed to ftp.ntransfercmd (the
        FTP REST command) so the transfer starts mid-file; this is the
        range-support addition.
        Returns a (file-like object, length) pair; length may be None.
        """
        self.endtransfer()
        if type in ('d', 'D'):
            cmd = 'TYPE A'
            isdir = 1
        else:
            cmd = 'TYPE ' + type
            isdir = 0
        try:
            self.ftp.voidcmd(cmd)
        except ftplib.all_errors:
            # connection may have gone away; reconnect and retry once
            self.init()
            self.ftp.voidcmd(cmd)
        conn = None
        if file and not isdir:
            # Use nlst to see if the file exists at all
            try:
                self.ftp.nlst(file)
            except ftplib.error_perm, reason:
                raise IOError('ftp error', reason), sys.exc_info()[2]
            # Restore the transfer mode!
            self.ftp.voidcmd(cmd)
            # Try to retrieve as a file
            try:
                cmd = 'RETR ' + file
                conn = self.ftp.ntransfercmd(cmd, rest)
            except ftplib.error_perm, reason:
                if str(reason).startswith('501'):
                    # workaround for REST not supported error
                    fp, retrlen = self.retrfile(file, type)
                    fp = RangeableFileObject(fp, (rest,''))
                    return (fp, retrlen)
                elif not str(reason).startswith('550'):
                    raise IOError('ftp error', reason), sys.exc_info()[2]
        if not conn:
            # Set transfer mode to ASCII!
            self.ftp.voidcmd('TYPE A')
            # Try a directory listing
            if file:
                cmd = 'LIST ' + file
            else:
                cmd = 'LIST'
            conn = self.ftp.ntransfercmd(cmd)
        self.busy = 1
        # Pass back both a suitably decorated object and a retrieval length
        return (addclosehook(conn[0].makefile('rb'),
                             self.endtransfer), conn[1])
387
387
388
388
389 ####################################################################
389 ####################################################################
390 # Range Tuple Functions
390 # Range Tuple Functions
391 # XXX: These range tuple functions might go better in a class.
391 # XXX: These range tuple functions might go better in a class.
392
392
# compiled lazily by range_header_to_tuple on first use
_rangere = None
def range_header_to_tuple(range_header):
    """Parse a "bytes=<firstbyte>-<lastbyte>" Range header value.

    Returns:
      None -- when range_header is None;
      ()   -- when the value does not match the range spec pattern;
      (firstbyte, lastbyte) -- otherwise, where lastbyte is one past
      the header's (inclusive) last byte, or '' when the header left
      the end of the range unspecified.
    """
    global _rangere
    if range_header is None:
        return None
    if _rangere is None:
        # compile the pattern once, on first call
        import re
        _rangere = re.compile(r'^bytes=(\d{1,})-(\d*)')
    match = _rangere.match(range_header)
    if not match:
        return ()
    tup = range_tuple_normalize(match.group(1, 2))
    if tup and tup[1]:
        # convert the header's inclusive end to an exclusive one
        tup = (tup[0], tup[1] + 1)
    return tup
421
421
def range_tuple_to_header(range_tup):
    """Convert a range tuple to a Range header value.

    Returns a string of the form "bytes=<firstbyte>-<lastbyte>", or
    None when no Range header is needed (None input, or a tuple that
    normalizes to the whole file).
    """
    if range_tup is None:
        return None
    normalized = range_tuple_normalize(range_tup)
    if not normalized:
        return None
    first, last = normalized
    if last:
        # header ranges are inclusive; internal tuples are exclusive
        last = last - 1
    return 'bytes=%s-%s' % (first, last)
434
434
def range_tuple_normalize(range_tup):
    """Normalize a (first_byte, last_byte) range tuple.

    The first element of the result is always an int; the second is
    either an int or '' (meaning "through the last byte").  Returns
    None for a None input or for a tuple equivalent to (0, '') -- i.e.
    retrieving the entire file.  Raises RangeError when the range is
    inverted.
    """
    if range_tup is None:
        return None
    # first byte: None or '' means "from the start of the file"
    first = range_tup[0]
    if first in (None, ''):
        first = 0
    else:
        first = int(first)
    # last byte: a missing or None entry means "to the end" ('')
    try:
        last = range_tup[1]
    except IndexError:
        last = ''
    else:
        if last is None:
            last = ''
        elif last != '':
            last = int(last)
    # (0, '') covers the whole file: no range needed
    if (first, last) == (0, ''):
        return None
    # reject inverted ranges
    if last < first:
        raise RangeError('Invalid byte range: %s-%s' % (first, last))
    return (first, last)
@@ -1,319 +1,319 b''
1 # hgweb/server.py - The standalone hg web server.
1 # hgweb/server.py - The standalone hg web server.
2 #
2 #
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 import os, sys, errno, urllib, BaseHTTPServer, socket, SocketServer, traceback
9 import os, sys, errno, urllib, BaseHTTPServer, socket, SocketServer, traceback
10 from mercurial import util, error
10 from mercurial import util, error
11 from mercurial.hgweb import common
11 from mercurial.hgweb import common
12 from mercurial.i18n import _
12 from mercurial.i18n import _
13
13
def _splitURI(uri):
    """Return the path and query components of a URI.

    As in a CGI environment, the returned path is unquoted while the
    query string is left untouched.
    """
    path, _sep, query = uri.partition('?')
    return urllib.unquote(path), query
25
25
26 class _error_logger(object):
26 class _error_logger(object):
27 def __init__(self, handler):
27 def __init__(self, handler):
28 self.handler = handler
28 self.handler = handler
29 def flush(self):
29 def flush(self):
30 pass
30 pass
31 def write(self, str):
31 def write(self, str):
32 self.writelines(str.split('\n'))
32 self.writelines(str.split('\n'))
33 def writelines(self, seq):
33 def writelines(self, seq):
34 for msg in seq:
34 for msg in seq:
35 self.handler.log_error("HG error: %s", msg)
35 self.handler.log_error("HG error: %s", msg)
36
36
class _httprequesthandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Serve the hgweb WSGI application (self.server.application) over
    plain HTTP.  The HTTPS handlers below subclass this and override
    url_scheme, preparehttpserver() and setup()."""

    # scheme advertised in the WSGI environ; overridden by SSL subclasses
    url_scheme = 'http'

    @staticmethod
    def preparehttpserver(httpserver, ssl_cert):
        """Prepare .socket of new HTTPServer instance"""
        # plain HTTP needs no socket wrapping; SSL subclasses override this
        pass

    def __init__(self, *args, **kargs):
        # HTTP/1.1 enables persistent (keep-alive) connections
        self.protocol_version = 'HTTP/1.1'
        BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kargs)

    def _log_any(self, fp, format, *args):
        # Write one "client - - [date] message" line and flush right away
        # so log output is visible while the server is running.
        fp.write("%s - - [%s] %s\n" % (self.client_address[0],
                                       self.log_date_time_string(),
                                       format % args))
        fp.flush()

    def log_error(self, format, *args):
        self._log_any(self.server.errorlog, format, *args)

    def log_message(self, format, *args):
        self._log_any(self.server.accesslog, format, *args)

    def log_request(self, code='-', size='-'):
        # Append any x-* request headers to the access-log entry.
        xheaders = [h for h in self.headers.items() if h[0].startswith('x-')]
        self.log_message('"%s" %s %s%s',
                         self.requestline, str(code), str(size),
                         ''.join([' %s:%s' % h for h in sorted(xheaders)]))

    def do_write(self):
        # Run the request; a broken pipe just means the client went away.
        try:
            self.do_hgweb()
        except socket.error, inst:
            if inst[0] != errno.EPIPE:
                raise

    def do_POST(self):
        try:
            self.do_write()
        except Exception:
            # Last-resort 500 response; log the full traceback server-side.
            self._start_response("500 Internal Server Error", [])
            self._write("Internal Server Error")
            tb = "".join(traceback.format_exception(*sys.exc_info()))
            self.log_error("Exception happened during processing "
                           "request '%s':\n%s", self.path, tb)

    def do_GET(self):
        # GET and POST are dispatched identically to the WSGI application.
        self.do_POST()

    def do_hgweb(self):
        """Build a CGI/WSGI environ from the request, run the WSGI
        application and stream its output to the client."""
        path, query = _splitURI(self.path)

        env = {}
        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
        env['REQUEST_METHOD'] = self.command
        env['SERVER_NAME'] = self.server.server_name
        env['SERVER_PORT'] = str(self.server.server_port)
        env['REQUEST_URI'] = self.path
        env['SCRIPT_NAME'] = self.server.prefix
        # strip the configured URL prefix off the request path
        env['PATH_INFO'] = path[len(self.server.prefix):]
        env['REMOTE_HOST'] = self.client_address[0]
        env['REMOTE_ADDR'] = self.client_address[0]
        if query:
            env['QUERY_STRING'] = query

        if self.headers.typeheader is None:
            env['CONTENT_TYPE'] = self.headers.type
        else:
            env['CONTENT_TYPE'] = self.headers.typeheader
        length = self.headers.getheader('content-length')
        if length:
            env['CONTENT_LENGTH'] = length
        # remaining request headers become HTTP_* environ entries
        for header in [h for h in self.headers.keys()
                       if h not in ('content-type', 'content-length')]:
            hkey = 'HTTP_' + header.replace('-', '_').upper()
            hval = self.headers.getheader(header)
            hval = hval.replace('\n', '').strip()
            if hval:
                env[hkey] = hval
        env['SERVER_PROTOCOL'] = self.request_version
        env['wsgi.version'] = (1, 0)
        env['wsgi.url_scheme'] = self.url_scheme
        if env.get('HTTP_EXPECT', '').lower() == '100-continue':
            # defer the "100 Continue" reply until the body is first read
            self.rfile = common.continuereader(self.rfile, self.wfile.write)

        env['wsgi.input'] = self.rfile
        env['wsgi.errors'] = _error_logger(self)
        env['wsgi.multithread'] = isinstance(self.server,
                                             SocketServer.ThreadingMixIn)
        env['wsgi.multiprocess'] = isinstance(self.server,
                                              SocketServer.ForkingMixIn)
        env['wsgi.run_once'] = 0

        # per-request state consumed by _start_response/_write/send_headers
        self.close_connection = True
        self.saved_status = None
        self.saved_headers = []
        self.sent_headers = False
        self.length = None
        for chunk in self.server.application(env, self._start_response):
            self._write(chunk)

    def send_headers(self):
        """Emit the status line and headers saved by _start_response().

        Adds 'Connection: close' unless a Content-Length header makes
        connection reuse safe (and the client did not request close)."""
        if not self.saved_status:
            raise AssertionError("Sending headers before "
                                 "start_response() called")
        saved_status = self.saved_status.split(None, 1)
        saved_status[0] = int(saved_status[0])
        self.send_response(*saved_status)
        should_close = True
        for h in self.saved_headers:
            self.send_header(*h)
            if h[0].lower() == 'content-length':
                # known length => the connection may be kept alive
                should_close = False
                self.length = int(h[1])
        # The value of the Connection header is a list of case-insensitive
        # tokens separated by commas and optional whitespace.
        if 'close' in [token.strip().lower() for token in
                       self.headers.get('connection', '').split(',')]:
            should_close = True
        if should_close:
            self.send_header('Connection', 'close')
        self.close_connection = should_close
        self.end_headers()
        self.sent_headers = True

    def _start_response(self, http_status, headers, exc_info=None):
        """WSGI start_response callable: record status and headers for
        send_headers(), filtering hop-by-hop headers."""
        code, msg = http_status.split(None, 1)
        code = int(code)
        self.saved_status = http_status
        bad_headers = ('connection', 'transfer-encoding')
        self.saved_headers = [h for h in headers
                              if h[0].lower() not in bad_headers]
        return self._write

    def _write(self, data):
        """Send one chunk of response body, emitting headers first if
        they have not been sent yet."""
        if not self.saved_status:
            raise AssertionError("data written before start_response() called")
        elif not self.sent_headers:
            self.send_headers()
        if self.length is not None:
            # enforce the declared Content-Length
            if len(data) > self.length:
                raise AssertionError("Content-length header sent, but more "
                                     "bytes than specified are being written.")
            self.length = self.length - len(data)
        self.wfile.write(data)
        self.wfile.flush()
185
185
class _httprequesthandleropenssl(_httprequesthandler):
    """HTTPS handler based on pyOpenSSL"""

    url_scheme = 'https'

    @staticmethod
    def preparehttpserver(httpserver, ssl_cert):
        # Wrap the server socket in an OpenSSL connection.  ssl_cert is a
        # single PEM file used for both the private key and the certificate.
        try:
            import OpenSSL
            OpenSSL.SSL.Context
        except ImportError:
            raise util.Abort(_("SSL support is unavailable"))
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        ctx.use_privatekey_file(ssl_cert)
        ctx.use_certificate_file(ssl_cert)
        sock = socket.socket(httpserver.address_family, httpserver.socket_type)
        httpserver.socket = OpenSSL.SSL.Connection(ctx, sock)
        httpserver.server_bind()
        httpserver.server_activate()

    def setup(self):
        # BaseHTTPRequestHandler expects buffered rfile/wfile objects, so
        # wrap the SSL connection accordingly.
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)

    def do_write(self):
        import OpenSSL
        try:
            _httprequesthandler.do_write(self)
        except OpenSSL.SSL.SysCallError, inst:
            # ignore broken pipes from clients disconnecting early
            if inst.args[0] != errno.EPIPE:
                raise

    def handle_one_request(self):
        import OpenSSL
        try:
            _httprequesthandler.handle_one_request(self)
        except (OpenSSL.SSL.SysCallError, OpenSSL.SSL.ZeroReturnError):
            # SSL-level EOF or error simply ends the connection
            self.close_connection = True
            pass
226
226
class _httprequesthandlerssl(_httprequesthandler):
    """HTTPS handler based on Pythons ssl module (introduced in 2.6)"""

    url_scheme = 'https'

    @staticmethod
    def preparehttpserver(httpserver, ssl_cert):
        # ssl_cert is a single PEM file holding key and certificate.
        try:
            import ssl
            ssl.wrap_socket
        except ImportError:
            raise util.Abort(_("SSL support is unavailable"))
        httpserver.socket = ssl.wrap_socket(httpserver.socket, server_side=True,
            certfile=ssl_cert, ssl_version=ssl.PROTOCOL_SSLv23)

    def setup(self):
        # BaseHTTPRequestHandler expects buffered rfile/wfile objects, so
        # wrap the SSL socket accordingly.
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
246
246
# Choose a concurrency mixin for the HTTP server: threads when the
# interpreter has thread support (probed via the threading import),
# otherwise fork where the platform supports it, otherwise none.
try:
    from threading import activeCount
    _mixin = SocketServer.ThreadingMixIn
except ImportError:
    if hasattr(os, "fork"):
        _mixin = SocketServer.ForkingMixIn
    else:
        # no threads and no fork: serve requests sequentially
        class _mixin(object):
            pass
256
256
def openlog(opt, default):
    """Open *opt* for appending and return the file object.

    Fall back to *default* when *opt* is empty, None, or '-' (the
    conventional placeholder for a standard stream).
    """
    use_default = not opt or opt == '-'
    if use_default:
        return default
    return open(opt, 'a')
261
261
class MercurialHTTPServer(object, _mixin, BaseHTTPServer.HTTPServer):
    """HTTP(S) server wired to a WSGI application and configured from
    the [web] section of a Mercurial ui object.  'object' is listed
    first so the class is new-style despite HTTPServer being a classic
    class on Python 2."""

    # SO_REUSEADDR has broken semantics on windows
    if os.name == 'nt':
        allow_reuse_address = 0

    def __init__(self, ui, app, addr, handler, **kwargs):
        BaseHTTPServer.HTTPServer.__init__(self, addr, handler, **kwargs)
        self.daemon_threads = True
        # the WSGI application served by the request handlers
        self.application = app

        # let the handler class wrap self.socket (no-op for plain HTTP,
        # SSL wrapping for the HTTPS handlers)
        handler.preparehttpserver(self, ui.config('web', 'certificate'))

        # normalized URL prefix: either '' or '/stripped-prefix'
        prefix = ui.config('web', 'prefix', '')
        if prefix:
            prefix = '/' + prefix.strip('/')
        self.prefix = prefix

        # '-' (the default) means log to the standard streams
        alog = openlog(ui.config('web', 'accesslog', '-'), sys.stdout)
        elog = openlog(ui.config('web', 'errorlog', '-'), sys.stderr)
        self.accesslog = alog
        self.errorlog = elog

        # actual bound address/port (port may have been 0 = "any")
        self.addr, self.port = self.socket.getsockname()[0:2]
        self.fqaddr = socket.getfqdn(addr[0])
287
287
class IPv6HTTPServer(MercurialHTTPServer):
    """MercurialHTTPServer bound to an IPv6 socket."""
    # None when the platform's socket module lacks IPv6 support
    address_family = getattr(socket, 'AF_INET6', None)
    def __init__(self, *args, **kwargs):
        if self.address_family is None:
            raise error.RepoError(_('IPv6 is not available on this system'))
        super(IPv6HTTPServer, self).__init__(*args, **kwargs)
294
294
def create_server(ui, app):
    """Create and return a MercurialHTTPServer (or its IPv6 variant)
    serving the WSGI application *app*, configured from *ui*.

    Raises util.Abort if the server socket cannot be bound.
    """

    # pick a request handler: HTTPS (stdlib ssl on >= 2.6, pyOpenSSL
    # otherwise) when a certificate is configured, plain HTTP otherwise
    if ui.config('web', 'certificate'):
        if sys.version_info >= (2, 6):
            handler = _httprequesthandlerssl
        else:
            handler = _httprequesthandleropenssl
    else:
        handler = _httprequesthandler

    if ui.configbool('web', 'ipv6'):
        cls = IPv6HTTPServer
    else:
        cls = MercurialHTTPServer

    # ugly hack due to python issue5853 (for threaded use)
    import mimetypes; mimetypes.init()

    address = ui.config('web', 'address', '')
    port = util.getport(ui.config('web', 'port', 8000))
    try:
        return cls(ui, app, (address, port), handler)
    except socket.error, inst:
        raise util.Abort(_("cannot start server at '%s:%d': %s")
                         % (address, port, inst.args[1]))
@@ -1,765 +1,765 b''
1 # This library is free software; you can redistribute it and/or
1 # This library is free software; you can redistribute it and/or
2 # modify it under the terms of the GNU Lesser General Public
2 # modify it under the terms of the GNU Lesser General Public
3 # License as published by the Free Software Foundation; either
3 # License as published by the Free Software Foundation; either
4 # version 2.1 of the License, or (at your option) any later version.
4 # version 2.1 of the License, or (at your option) any later version.
5 #
5 #
6 # This library is distributed in the hope that it will be useful,
6 # This library is distributed in the hope that it will be useful,
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
7 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
8 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9 # Lesser General Public License for more details.
9 # Lesser General Public License for more details.
10 #
10 #
11 # You should have received a copy of the GNU Lesser General Public
11 # You should have received a copy of the GNU Lesser General Public
12 # License along with this library; if not, write to the
12 # License along with this library; if not, write to the
13 # Free Software Foundation, Inc.,
13 # Free Software Foundation, Inc.,
14 # 59 Temple Place, Suite 330,
14 # 59 Temple Place, Suite 330,
15 # Boston, MA 02111-1307 USA
15 # Boston, MA 02111-1307 USA
16
16
17 # This file is part of urlgrabber, a high-level cross-protocol url-grabber
17 # This file is part of urlgrabber, a high-level cross-protocol url-grabber
18 # Copyright 2002-2004 Michael D. Stenner, Ryan Tomayko
18 # Copyright 2002-2004 Michael D. Stenner, Ryan Tomayko
19
19
20 # Modified by Benoit Boissinot:
20 # Modified by Benoit Boissinot:
21 # - fix for digest auth (inspired from urllib2.py @ Python v2.4)
21 # - fix for digest auth (inspired from urllib2.py @ Python v2.4)
22 # Modified by Dirkjan Ochtman:
22 # Modified by Dirkjan Ochtman:
23 # - import md5 function from a local util module
23 # - import md5 function from a local util module
24 # Modified by Martin Geisler:
24 # Modified by Martin Geisler:
25 # - moved md5 function from local util module to this module
25 # - moved md5 function from local util module to this module
26 # Modified by Augie Fackler:
26 # Modified by Augie Fackler:
27 # - add safesend method and use it to prevent broken pipe errors
27 # - add safesend method and use it to prevent broken pipe errors
28 # on large POST requests
28 # on large POST requests
29
29
30 """An HTTP handler for urllib2 that supports HTTP 1.1 and keepalive.
30 """An HTTP handler for urllib2 that supports HTTP 1.1 and keepalive.
31
31
32 >>> import urllib2
32 >>> import urllib2
33 >>> from keepalive import HTTPHandler
33 >>> from keepalive import HTTPHandler
34 >>> keepalive_handler = HTTPHandler()
34 >>> keepalive_handler = HTTPHandler()
35 >>> opener = urllib2.build_opener(keepalive_handler)
35 >>> opener = urllib2.build_opener(keepalive_handler)
36 >>> urllib2.install_opener(opener)
36 >>> urllib2.install_opener(opener)
37 >>>
37 >>>
38 >>> fo = urllib2.urlopen('http://www.python.org')
38 >>> fo = urllib2.urlopen('http://www.python.org')
39
39
40 If a connection to a given host is requested, and all of the existing
40 If a connection to a given host is requested, and all of the existing
41 connections are still in use, another connection will be opened. If
41 connections are still in use, another connection will be opened. If
42 the handler tries to use an existing connection but it fails in some
42 the handler tries to use an existing connection but it fails in some
43 way, it will be closed and removed from the pool.
43 way, it will be closed and removed from the pool.
44
44
45 To remove the handler, simply re-run build_opener with no arguments, and
45 To remove the handler, simply re-run build_opener with no arguments, and
46 install that opener.
46 install that opener.
47
47
48 You can explicitly close connections by using the close_connection()
48 You can explicitly close connections by using the close_connection()
49 method of the returned file-like object (described below) or you can
49 method of the returned file-like object (described below) or you can
50 use the handler methods:
50 use the handler methods:
51
51
52 close_connection(host)
52 close_connection(host)
53 close_all()
53 close_all()
54 open_connections()
54 open_connections()
55
55
56 NOTE: using the close_connection and close_all methods of the handler
56 NOTE: using the close_connection and close_all methods of the handler
57 should be done with care when using multiple threads.
57 should be done with care when using multiple threads.
58 * there is nothing that prevents another thread from creating new
58 * there is nothing that prevents another thread from creating new
59 connections immediately after connections are closed
59 connections immediately after connections are closed
60 * no checks are done to prevent in-use connections from being closed
60 * no checks are done to prevent in-use connections from being closed
61
61
62 >>> keepalive_handler.close_all()
62 >>> keepalive_handler.close_all()
63
63
64 EXTRA ATTRIBUTES AND METHODS
64 EXTRA ATTRIBUTES AND METHODS
65
65
66 Upon a status of 200, the object returned has a few additional
66 Upon a status of 200, the object returned has a few additional
67 attributes and methods, which should not be used if you want to
67 attributes and methods, which should not be used if you want to
68 remain consistent with the normal urllib2-returned objects:
68 remain consistent with the normal urllib2-returned objects:
69
69
70 close_connection() - close the connection to the host
70 close_connection() - close the connection to the host
71 readlines() - you know, readlines()
71 readlines() - you know, readlines()
72 status - the return status (ie 404)
72 status - the return status (ie 404)
73 reason - english translation of status (ie 'File not found')
73 reason - english translation of status (ie 'File not found')
74
74
75 If you want the best of both worlds, use this inside an
75 If you want the best of both worlds, use this inside an
76 AttributeError-catching try:
76 AttributeError-catching try:
77
77
78 >>> try: status = fo.status
78 >>> try: status = fo.status
79 >>> except AttributeError: status = None
79 >>> except AttributeError: status = None
80
80
81 Unfortunately, these are ONLY there if status == 200, so it's not
81 Unfortunately, these are ONLY there if status == 200, so it's not
82 easy to distinguish between non-200 responses. The reason is that
82 easy to distinguish between non-200 responses. The reason is that
83 urllib2 tries to do clever things with error codes 301, 302, 401,
83 urllib2 tries to do clever things with error codes 301, 302, 401,
84 and 407, and it wraps the object upon return.
84 and 407, and it wraps the object upon return.
85
85
86 For python versions earlier than 2.4, you can avoid this fancy error
86 For python versions earlier than 2.4, you can avoid this fancy error
87 handling by setting the module-level global HANDLE_ERRORS to zero.
87 handling by setting the module-level global HANDLE_ERRORS to zero.
88 You see, prior to 2.4, it's the HTTP Handler's job to determine what
88 You see, prior to 2.4, it's the HTTP Handler's job to determine what
89 to handle specially, and what to just pass up. HANDLE_ERRORS == 0
89 to handle specially, and what to just pass up. HANDLE_ERRORS == 0
90 means "pass everything up". In python 2.4, however, this job no
90 means "pass everything up". In python 2.4, however, this job no
91 longer belongs to the HTTP Handler and is now done by a NEW handler,
91 longer belongs to the HTTP Handler and is now done by a NEW handler,
92 HTTPErrorProcessor. Here's the bottom line:
92 HTTPErrorProcessor. Here's the bottom line:
93
93
94 python version < 2.4
94 python version < 2.4
95 HANDLE_ERRORS == 1 (default) pass up 200, treat the rest as
95 HANDLE_ERRORS == 1 (default) pass up 200, treat the rest as
96 errors
96 errors
97 HANDLE_ERRORS == 0 pass everything up, error processing is
97 HANDLE_ERRORS == 0 pass everything up, error processing is
98 left to the calling code
98 left to the calling code
99 python version >= 2.4
99 python version >= 2.4
100 HANDLE_ERRORS == 1 pass up 200, treat the rest as errors
100 HANDLE_ERRORS == 1 pass up 200, treat the rest as errors
101 HANDLE_ERRORS == 0 (default) pass everything up, let the
101 HANDLE_ERRORS == 0 (default) pass everything up, let the
102 other handlers (specifically,
102 other handlers (specifically,
103 HTTPErrorProcessor) decide what to do
103 HTTPErrorProcessor) decide what to do
104
104
105 In practice, setting the variable either way makes little difference
105 In practice, setting the variable either way makes little difference
106 in python 2.4, so for the most consistent behavior across versions,
106 in python 2.4, so for the most consistent behavior across versions,
107 you probably just want to use the defaults, which will give you
107 you probably just want to use the defaults, which will give you
108 exceptions on errors.
108 exceptions on errors.
109
109
110 """
110 """
111
111
112 # $Id: keepalive.py,v 1.14 2006/04/04 21:00:32 mstenner Exp $
112 # $Id: keepalive.py,v 1.14 2006/04/04 21:00:32 mstenner Exp $
113
113
114 import errno
114 import errno
115 import httplib
115 import httplib
116 import socket
116 import socket
117 import thread
117 import thread
118 import urllib2
118 import urllib2
119
119
120 DEBUG = None
120 DEBUG = None
121
121
122 import sys
122 import sys
123 if sys.version_info < (2, 4):
123 if sys.version_info < (2, 4):
124 HANDLE_ERRORS = 1
124 HANDLE_ERRORS = 1
125 else: HANDLE_ERRORS = 0
125 else: HANDLE_ERRORS = 0
126
126
class ConnectionManager(object):
    """
    The connection manager must be able to:
      * keep track of all existing connections
      * map connections to their host and back
      * track whether each connection is ready for reuse
    """
    def __init__(self):
        # this lock guards all three maps below
        self._lock = thread.allocate_lock()
        self._hostmap = {} # map hosts to a list of connections
        self._connmap = {} # map connections to host
        self._readymap = {} # map connection to ready state

    def add(self, host, connection, ready):
        """Register *connection* for *host* with the given ready state."""
        self._lock.acquire()
        try:
            self._hostmap.setdefault(host, []).append(connection)
            self._connmap[connection] = host
            self._readymap[connection] = ready
        finally:
            self._lock.release()

    def remove(self, connection):
        """Forget *connection*; unknown connections are ignored."""
        self._lock.acquire()
        try:
            try:
                host = self._connmap[connection]
            except KeyError:
                pass
            else:
                del self._connmap[connection]
                del self._readymap[connection]
                self._hostmap[host].remove(connection)
                if not self._hostmap[host]:
                    del self._hostmap[host]
        finally:
            self._lock.release()

    def set_ready(self, connection, ready):
        """Set the ready state of *connection*.

        NOTE(review): a dict key assignment never raises KeyError, so the
        except clause below is dead code and unknown connections are in
        fact added to the map; kept as-is to preserve behavior.  Also not
        locked in the original implementation.
        """
        try:
            self._readymap[connection] = ready
        except KeyError:
            pass

    def get_ready_conn(self, host):
        """Return a ready connection for *host*, marking it busy, or
        None when no ready connection exists."""
        conn = None
        self._lock.acquire()
        try:
            for c in self._hostmap.get(host, []):
                if self._readymap[c]:
                    self._readymap[c] = 0
                    conn = c
                    break
        finally:
            self._lock.release()
        return conn

    def get_all(self, host=None):
        """Return a copy of *host*'s connection list, or a shallow copy
        of the whole host map when *host* is not given."""
        if host:
            return list(self._hostmap.get(host, []))
        else:
            return dict(self._hostmap)
189
189
class KeepAliveHandler(object):
    """urllib2-style handler base that reuses HTTP connections.

    Pools connections in a ConnectionManager; do_open() first tries any
    idle connection for the target host and only opens a new one when no
    pooled connection works.  Responses are annotated so that closing
    them hands the connection back to the pool.

    NOTE(review): self.parent is presumably set by urllib2's
    OpenerDirector when the handler is installed (standard urllib2
    handler protocol) -- it is read but never assigned here.
    """
    def __init__(self):
        self._cm = ConnectionManager()

    #### Connection Management
    def open_connections(self):
        """return a list of connected hosts and the number of connections
        to each.  [('foo.com:80', 2), ('bar.org', 1)]"""
        return [(host, len(li)) for (host, li) in self._cm.get_all().items()]

    def close_connection(self, host):
        """close connection(s) to <host>
        host is the host:port spec, as in 'www.cnn.com:8080' as passed in.
        no error occurs if there is no connection to that host."""
        for h in self._cm.get_all(host):
            self._cm.remove(h)
            h.close()

    def close_all(self):
        """close all open connections"""
        for host, conns in self._cm.get_all().iteritems():
            for h in conns:
                self._cm.remove(h)
                h.close()

    def _request_closed(self, request, host, connection):
        """tells us that this request is now closed and that the
        connection is ready for another request"""
        self._cm.set_ready(connection, 1)

    def _remove_connection(self, host, connection, close=0):
        # Drop *connection* from the pool; optionally close the socket too.
        if close:
            connection.close()
        self._cm.remove(connection)

    #### Transaction Execution
    def http_open(self, req):
        # Entry point called by urllib2 for http:// URLs.
        return self.do_open(HTTPConnection, req)

    def do_open(self, http_class, req):
        # Perform *req*, preferring a pooled connection to *host*; falls
        # back to a brand-new http_class connection when reuse fails.
        host = req.get_host()
        if not host:
            raise urllib2.URLError('no host given')

        try:
            h = self._cm.get_ready_conn(host)
            while h:
                r = self._reuse_connection(h, req, host)

                # if this response is non-None, then it worked and we're
                # done.  Break out, skipping the else block.
                if r:
                    break

                # connection is bad - possibly closed by server
                # discard it and ask for the next free connection
                h.close()
                self._cm.remove(h)
                h = self._cm.get_ready_conn(host)
            else:
                # no (working) free connections were found.  Create a new one.
                # (while/else: this runs only when the loop was never broken,
                # i.e. every pooled connection failed or none existed)
                h = http_class(host)
                if DEBUG:
                    DEBUG.info("creating new connection to %s (%d)",
                               host, id(h))
                self._cm.add(host, h, 0)
                self._start_transaction(h, req)
                r = h.getresponse()
        except (socket.error, httplib.HTTPException), err:
            raise urllib2.URLError(err)

        # if not a persistent connection, don't try to reuse it
        if r.will_close:
            self._cm.remove(h)

        if DEBUG:
            DEBUG.info("STATUS: %s, %s", r.status, r.reason)
        # annotate the response so close() can return the connection to
        # the pool, and mimic the attribute layout urllib2 callers expect
        r._handler = self
        r._host = host
        r._url = req.get_full_url()
        r._connection = h
        r.code = r.status
        r.headers = r.msg
        r.msg = r.reason

        if r.status == 200 or not HANDLE_ERRORS:
            return r
        else:
            # delegate non-200 responses to the opener's error machinery
            return self.parent.error('http', req, r,
                                     r.status, r.msg, r.headers)

    def _reuse_connection(self, h, req, host):
        """start the transaction with a re-used connection
        return a response object (r) upon success or None on failure.
        This DOES not close or remove bad connections in cases where
        it returns.  However, if an unexpected exception occurs, it
        will close and remove the connection before re-raising.
        """
        try:
            self._start_transaction(h, req)
            r = h.getresponse()
            # note: just because we got something back doesn't mean it
            # worked.  We'll check the version below, too.
        except (socket.error, httplib.HTTPException):
            r = None
        except:
            # adding this block just in case we've missed
            # something we will still raise the exception, but
            # lets try and close the connection and remove it
            # first.  We previously got into a nasty loop
            # where an exception was uncaught, and so the
            # connection stayed open.  On the next try, the
            # same exception was raised, etc.  The tradeoff is
            # that it's now possible this call will raise
            # a DIFFERENT exception
            if DEBUG:
                DEBUG.error("unexpected exception - closing "
                            "connection to %s (%d)", host, id(h))
            self._cm.remove(h)
            h.close()
            raise

        if r is None or r.version == 9:
            # httplib falls back to assuming HTTP 0.9 if it gets a
            # bad header back.  This is most likely to happen if
            # the socket has been closed by the server since we
            # last used the connection.
            if DEBUG:
                DEBUG.info("failed to re-use connection to %s (%d)",
                           host, id(h))
            r = None
        else:
            if DEBUG:
                DEBUG.info("re-using connection to %s (%d)", host, id(h))

        return r

    def _start_transaction(self, h, req):
        # What follows mostly reimplements HTTPConnection.request()
        # except it adds self.parent.addheaders in the mix.
        headers = req.headers.copy()
        if sys.version_info >= (2, 4):
            headers.update(req.unredirected_hdrs)
        headers.update(self.parent.addheaders)
        # normalize header names to lowercase for the duplicate checks below
        headers = dict((n.lower(), v) for n, v in headers.items())
        skipheaders = {}
        for n in ('host', 'accept-encoding'):
            if n in headers:
                # tell httplib.putrequest not to auto-generate these headers,
                # since we are about to send our own copies
                skipheaders['skip_' + n.replace('-', '_')] = 1
        try:
            if req.has_data():
                data = req.get_data()
                h.putrequest('POST', req.get_selector(), **skipheaders)
                if 'content-type' not in headers:
                    h.putheader('Content-type',
                                'application/x-www-form-urlencoded')
                if 'content-length' not in headers:
                    h.putheader('Content-length', '%d' % len(data))
            else:
                h.putrequest('GET', req.get_selector(), **skipheaders)
        except (socket.error), err:
            raise urllib2.URLError(err)
        for k, v in headers.items():
            h.putheader(k, v)
        h.endheaders()
        if req.has_data():
            h.send(data)
357
357
class HTTPHandler(KeepAliveHandler, urllib2.HTTPHandler):
    # Concrete urllib2 handler: KeepAliveHandler (listed first, so its
    # http_open wins in the MRO) supplies the pooling machinery, while
    # urllib2.HTTPHandler supplies the registration glue build_opener()
    # expects.
    pass
360
360
class HTTPResponse(httplib.HTTPResponse):
    """httplib.HTTPResponse subclass with a read buffer, readline(),
    readlines(), close_connection(), info() and geturl()."""
    # we need to subclass HTTPResponse in order to
    # 1) add readline() and readlines() methods
    # 2) add close_connection() methods
    # 3) add info() and geturl() methods

    # in order to add readline(), read must be modified to deal with a
    # buffer.  example: readline must read a buffer and then spit back
    # one line at a time.  The only real alternative is to read one
    # BYTE at a time (ick).  Once something has been read, it can't be
    # put back (ok, maybe it can, but that's even uglier than this),
    # so if you THEN do a normal read, you must first take stuff from
    # the buffer.

    # the read method wraps the original to accommodate buffering,
    # although read() never adds to the buffer.
    # Both readline and readlines have been stolen with almost no
    # modification from socket.py


    def __init__(self, sock, debuglevel=0, strict=0, method=None):
        if method: # the httplib in python 2.3 uses the method arg
            httplib.HTTPResponse.__init__(self, sock, debuglevel, method)
        else: # 2.2 doesn't
            httplib.HTTPResponse.__init__(self, sock, debuglevel)
        self.fileno = sock.fileno
        self.code = None
        self._rbuf = ''          # readline() look-ahead buffer
        self._rbufsize = 8096    # chunk size for refilling the buffer
        self._handler = None # inserted by the handler later
        self._host = None # (same)
        self._url = None # (same)
        self._connection = None # (same)

    # keep a handle on the unwrapped read so read() below can buffer on
    # top of it
    _raw_read = httplib.HTTPResponse.read

    def close(self):
        # Close the response body; if a keepalive handler annotated us,
        # tell it the connection is free for reuse.
        if self.fp:
            self.fp.close()
            self.fp = None
            if self._handler:
                self._handler._request_closed(self, self._host,
                                              self._connection)

    def close_connection(self):
        # Unlike close(), this also tears down the underlying connection
        # instead of returning it to the pool.
        self._handler._remove_connection(self._host, self._connection, close=1)
        self.close()

    def info(self):
        # urllib2 file-object API: the response headers.
        return self.headers

    def geturl(self):
        # urllib2 file-object API: the URL that was actually fetched.
        return self._url

    def read(self, amt=None):
        # Serve from the readline buffer first, then fall through to the
        # raw httplib read.
        # the _rbuf test is only in this first if for speed.  It's not
        # logically necessary
        if self._rbuf and not amt is None:
            L = len(self._rbuf)
            if amt > L:
                amt -= L
            else:
                s = self._rbuf[:amt]
                self._rbuf = self._rbuf[amt:]
                return s

        s = self._rbuf + self._raw_read(amt)
        self._rbuf = ''
        return s

    # stolen from Python SVN #68532 to fix issue1088
    def _read_chunked(self, amt):
        # Decode a "Transfer-Encoding: chunked" body, honouring a partial
        # read of *amt* bytes by remembering self.chunk_left across calls.
        chunk_left = self.chunk_left
        value = ''

        # XXX This accumulates chunks by repeated string concatenation,
        # which is not efficient as the number or size of chunks gets big.
        while True:
            if chunk_left is None:
                line = self.fp.readline()
                i = line.find(';')
                if i >= 0:
                    line = line[:i] # strip chunk-extensions
                try:
                    chunk_left = int(line, 16)
                except ValueError:
                    # close the connection as protocol synchronisation is
                    # probably lost
                    self.close()
                    raise httplib.IncompleteRead(value)
                if chunk_left == 0:
                    break
            if amt is None:
                value += self._safe_read(chunk_left)
            elif amt < chunk_left:
                value += self._safe_read(amt)
                self.chunk_left = chunk_left - amt
                return value
            elif amt == chunk_left:
                value += self._safe_read(amt)
                self._safe_read(2)  # toss the CRLF at the end of the chunk
                self.chunk_left = None
                return value
            else:
                value += self._safe_read(chunk_left)
                amt -= chunk_left

            # we read the whole chunk, get another
            self._safe_read(2)      # toss the CRLF at the end of the chunk
            chunk_left = None

        # read and discard trailer up to the CRLF terminator
        ### note: we shouldn't have any trailers!
        while True:
            line = self.fp.readline()
            if not line:
                # a vanishingly small number of sites EOF without
                # sending the trailer
                break
            if line == '\r\n':
                break

        # we read everything; close the "file"
        self.close()

        return value

    def readline(self, limit=-1):
        # Return one line (up to *limit* bytes if limit >= 0), refilling
        # the internal buffer from _raw_read as needed.
        i = self._rbuf.find('\n')
        while i < 0 and not (0 < limit <= len(self._rbuf)):
            new = self._raw_read(self._rbufsize)
            if not new:
                break
            i = new.find('\n')
            if i >= 0:
                i = i + len(self._rbuf)
            self._rbuf = self._rbuf + new
        if i < 0:
            # no newline found: return whatever we have (EOF case)
            i = len(self._rbuf)
        else:
            i = i + 1  # include the newline itself
        if 0 <= limit < len(self._rbuf):
            i = limit
        data, self._rbuf = self._rbuf[:i], self._rbuf[i:]
        return data

    def readlines(self, sizehint = 0):
        # Collect lines until EOF, or until at least *sizehint* bytes have
        # been gathered.  (The local name shadows the 'list' builtin --
        # kept as-is here.)
        total = 0
        list = []
        while True:
            line = self.readline()
            if not line:
                break
            list.append(line)
            total += len(line)
            if sizehint and total >= sizehint:
                break
        return list
519
519
def safesend(self, str):
    """Send `str' to the server.

    Shamelessly ripped off from httplib to patch a bad behavior.

    Unlike the stock httplib send(), a broken pipe while the request is
    in flight is recovered from when the server already produced a
    response: that response is stashed in self._broken_pipe_resp for the
    wrapped getresponse() to hand back, and the error is swallowed.
    (Installed as HTTPConnection.send, hence the 'self' parameter; the
    parameter name shadowing the 'str' builtin is inherited from httplib.)
    """
    # _broken_pipe_resp is an attribute we set in this function
    # if the socket is closed while we're sending data but
    # the server sent us a response before hanging up.
    # In that case, we want to pretend to send the rest of the
    # outgoing data, and then let the user use getresponse()
    # (which we wrap) to get this last response before
    # opening a new socket.
    if getattr(self, '_broken_pipe_resp', None) is not None:
        return

    if self.sock is None:
        if self.auto_open:
            self.connect()
        else:
            raise httplib.NotConnected()

    # send the data to the server. if we get a broken pipe, then close
    # the socket. we want to reconnect when somebody tries to send again.
    #
    # NOTE: we DO propagate the error, though, because we cannot simply
    #       ignore the error... the caller will know if they can retry.
    if self.debuglevel > 0:
        print "send:", repr(str)
    try:
        blocksize = 8192
        # file-like payloads are streamed in blocks instead of being
        # read into memory whole
        if hasattr(str,'read') :
            if self.debuglevel > 0:
                print "sendIng a read()able"
            data = str.read(blocksize)
            while data:
                self.sock.sendall(data)
                data = str.read(blocksize)
        else:
            self.sock.sendall(str)
    except socket.error, v:
        reraise = True
        if v[0] == errno.EPIPE:      # Broken pipe
            if self._HTTPConnection__state == httplib._CS_REQ_SENT:
                # request fully sent before the pipe broke: the server
                # may have answered already, so grab that response now
                # (assign None first so getresponse() takes the normal path)
                self._broken_pipe_resp = None
                self._broken_pipe_resp = self.getresponse()
                reraise = False
            self.close()
        if reraise:
            raise
569
569
def wrapgetresponse(cls):
    """Wraps getresponse in cls with a broken-pipe sane version.

    Returns a replacement method that first hands back any response
    stashed by safesend() after a broken pipe, and otherwise defers to
    cls.getresponse.  The original docstring is preserved.
    """
    def safegetresponse(self):
        cached = getattr(self, '_broken_pipe_resp', None)
        if cached is None:
            # no broken-pipe response pending: normal path
            return cls.getresponse(self)
        return cached
    safegetresponse.__doc__ = cls.getresponse.__doc__
    return safegetresponse
584
584
class HTTPConnection(httplib.HTTPConnection):
    """httplib.HTTPConnection with keepalive-aware send/getresponse."""
    # use the modified response class
    response_class = HTTPResponse
    # module-level safesend becomes the bound send() method
    send = safesend
    # getresponse wrapped to return a response cached by safesend()
    # after a broken pipe, when one exists
    getresponse = wrapgetresponse(httplib.HTTPConnection)
590
590
591
591
592 #########################################################################
592 #########################################################################
593 ##### TEST FUNCTIONS
593 ##### TEST FUNCTIONS
594 #########################################################################
594 #########################################################################
595
595
def error_handler(url):
    """Self-test: fetch *url* with fancy error handling off and then on.

    Temporarily flips the module-global HANDLE_ERRORS switch (restoring
    it afterwards), prints the status/reason of each attempt, then
    reports and closes all pooled connections.
    """
    global HANDLE_ERRORS
    orig = HANDLE_ERRORS
    keepalive_handler = HTTPHandler()
    opener = urllib2.build_opener(keepalive_handler)
    urllib2.install_opener(opener)
    pos = {0: 'off', 1: 'on'}
    for i in (0, 1):
        print " fancy error handling %s (HANDLE_ERRORS = %i)" % (pos[i], i)
        HANDLE_ERRORS = i
        try:
            fo = urllib2.urlopen(url)
            fo.read()
            fo.close()
            try:
                status, reason = fo.status, fo.reason
            except AttributeError:
                # presumably a non-keepalive response object without
                # status/reason attributes -- report None instead
                status, reason = None, None
        except IOError, e:
            print " EXCEPTION: %s" % e
            raise
        else:
            print " status = %s, reason = %s" % (status, reason)
    HANDLE_ERRORS = orig
    hosts = keepalive_handler.open_connections()
    print "open connections:", hosts
    keepalive_handler.close_all()
623
623
def md5(s):
    """Return an MD5 hash object for *s*, resolving the implementation lazily.

    On first use the module-global name ``md5`` is rebound to the real
    hash constructor (hashlib.md5 where available, the legacy md5 module
    otherwise), so later calls skip this shim entirely.
    """
    global md5
    try:
        from hashlib import md5 as impl
    except ImportError:
        # pre-hashlib Pythons ship a standalone md5 module
        from md5 import md5 as impl
    md5 = impl
    return impl(s)
632
632
633 def continuity(url):
633 def continuity(url):
634 format = '%25s: %s'
634 format = '%25s: %s'
635
635
636 # first fetch the file with the normal http handler
636 # first fetch the file with the normal http handler
637 opener = urllib2.build_opener()
637 opener = urllib2.build_opener()
638 urllib2.install_opener(opener)
638 urllib2.install_opener(opener)
639 fo = urllib2.urlopen(url)
639 fo = urllib2.urlopen(url)
640 foo = fo.read()
640 foo = fo.read()
641 fo.close()
641 fo.close()
642 m = md5.new(foo)
642 m = md5.new(foo)
643 print format % ('normal urllib', m.hexdigest())
643 print format % ('normal urllib', m.hexdigest())
644
644
645 # now install the keepalive handler and try again
645 # now install the keepalive handler and try again
646 opener = urllib2.build_opener(HTTPHandler())
646 opener = urllib2.build_opener(HTTPHandler())
647 urllib2.install_opener(opener)
647 urllib2.install_opener(opener)
648
648
649 fo = urllib2.urlopen(url)
649 fo = urllib2.urlopen(url)
650 foo = fo.read()
650 foo = fo.read()
651 fo.close()
651 fo.close()
652 m = md5.new(foo)
652 m = md5.new(foo)
653 print format % ('keepalive read', m.hexdigest())
653 print format % ('keepalive read', m.hexdigest())
654
654
655 fo = urllib2.urlopen(url)
655 fo = urllib2.urlopen(url)
656 foo = ''
656 foo = ''
657 while True:
657 while True:
658 f = fo.readline()
658 f = fo.readline()
659 if f:
659 if f:
660 foo = foo + f
660 foo = foo + f
661 else: break
661 else: break
662 fo.close()
662 fo.close()
663 m = md5.new(foo)
663 m = md5.new(foo)
664 print format % ('keepalive readline', m.hexdigest())
664 print format % ('keepalive readline', m.hexdigest())
665
665
def comp(N, url):
    """Self-test: time N fetches of *url* with and without keepalive and
    print the speed improvement factor."""
    print ' making %i connections to:\n %s' % (N, url)

    sys.stdout.write(' first using the normal urllib handlers')
    # first use normal opener
    opener = urllib2.build_opener()
    urllib2.install_opener(opener)
    t1 = fetch(N, url)
    print ' TIME: %.3f s' % t1

    sys.stdout.write(' now using the keepalive handler ')
    # now install the keepalive handler and try again
    opener = urllib2.build_opener(HTTPHandler())
    urllib2.install_opener(opener)
    t2 = fetch(N, url)
    print ' TIME: %.3f s' % t2
    print ' improvement factor: %.2f' % (t1 / t2)
683
683
684 def fetch(N, url, delay=0):
684 def fetch(N, url, delay=0):
685 import time
685 import time
686 lens = []
686 lens = []
687 starttime = time.time()
687 starttime = time.time()
688 for i in range(N):
688 for i in range(N):
689 if delay and i > 0:
689 if delay and i > 0:
690 time.sleep(delay)
690 time.sleep(delay)
691 fo = urllib2.urlopen(url)
691 fo = urllib2.urlopen(url)
692 foo = fo.read()
692 foo = fo.read()
693 fo.close()
693 fo.close()
694 lens.append(len(foo))
694 lens.append(len(foo))
695 diff = time.time() - starttime
695 diff = time.time() - starttime
696
696
697 j = 0
697 j = 0
698 for i in lens[1:]:
698 for i in lens[1:]:
699 j = j + 1
699 j = j + 1
700 if not i == lens[0]:
700 if not i == lens[0]:
701 print "WARNING: inconsistent length on read %i: %i" % (j, i)
701 print "WARNING: inconsistent length on read %i: %i" % (j, i)
702
702
703 return diff
703 return diff
704
704
705 def test_timeout(url):
705 def test_timeout(url):
706 global DEBUG
706 global DEBUG
707 dbbackup = DEBUG
707 dbbackup = DEBUG
708 class FakeLogger:
708 class FakeLogger(object):
709 def debug(self, msg, *args):
709 def debug(self, msg, *args):
710 print msg % args
710 print msg % args
711 info = warning = error = debug
711 info = warning = error = debug
712 DEBUG = FakeLogger()
712 DEBUG = FakeLogger()
713 print " fetching the file to establish a connection"
713 print " fetching the file to establish a connection"
714 fo = urllib2.urlopen(url)
714 fo = urllib2.urlopen(url)
715 data1 = fo.read()
715 data1 = fo.read()
716 fo.close()
716 fo.close()
717
717
718 i = 20
718 i = 20
719 print " waiting %i seconds for the server to close the connection" % i
719 print " waiting %i seconds for the server to close the connection" % i
720 while i > 0:
720 while i > 0:
721 sys.stdout.write('\r %2i' % i)
721 sys.stdout.write('\r %2i' % i)
722 sys.stdout.flush()
722 sys.stdout.flush()
723 time.sleep(1)
723 time.sleep(1)
724 i -= 1
724 i -= 1
725 sys.stderr.write('\r')
725 sys.stderr.write('\r')
726
726
727 print " fetching the file a second time"
727 print " fetching the file a second time"
728 fo = urllib2.urlopen(url)
728 fo = urllib2.urlopen(url)
729 data2 = fo.read()
729 data2 = fo.read()
730 fo.close()
730 fo.close()
731
731
732 if data1 == data2:
732 if data1 == data2:
733 print ' data are identical'
733 print ' data are identical'
734 else:
734 else:
735 print ' ERROR: DATA DIFFER'
735 print ' ERROR: DATA DIFFER'
736
736
737 DEBUG = dbbackup
737 DEBUG = dbbackup
738
738
739
739
def test(url, N=10):
    """Run the whole self-test suite against *url* (N fetches for timing).

    Exits early if the error-handler check raises IOError, since the
    later tests would fail the same way.
    """
    # (the "hander" typo below is in the user-visible string; left as-is)
    print "checking error hander (do this on a non-200)"
    try: error_handler(url)
    except IOError:
        print "exiting - exception will prevent further tests"
        sys.exit()
    print
    print "performing continuity test (making sure stuff isn't corrupted)"
    continuity(url)
    print
    print "performing speed comparison"
    comp(N, url)
    print
    print "performing dropped-connection check"
    test_timeout(url)
755
755
756 if __name__ == '__main__':
756 if __name__ == '__main__':
757 import time
757 import time
758 import sys
758 import sys
759 try:
759 try:
760 N = int(sys.argv[1])
760 N = int(sys.argv[1])
761 url = sys.argv[2]
761 url = sys.argv[2]
762 except:
762 except:
763 print "%s <integer> <url>" % sys.argv[0]
763 print "%s <integer> <url>" % sys.argv[0]
764 else:
764 else:
765 test(url, N)
765 test(url, N)
@@ -1,1859 +1,1859 b''
1 # patch.py - patch file parsing routines
1 # patch.py - patch file parsing routines
2 #
2 #
3 # Copyright 2006 Brendan Cully <brendan@kublai.com>
3 # Copyright 2006 Brendan Cully <brendan@kublai.com>
4 # Copyright 2007 Chris Mason <chris.mason@oracle.com>
4 # Copyright 2007 Chris Mason <chris.mason@oracle.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 import cStringIO, email.Parser, os, errno, re
9 import cStringIO, email.Parser, os, errno, re
10 import tempfile, zlib, shutil
10 import tempfile, zlib, shutil
11
11
12 from i18n import _
12 from i18n import _
13 from node import hex, nullid, short
13 from node import hex, nullid, short
14 import base85, mdiff, scmutil, util, diffhelpers, copies, encoding, error
14 import base85, mdiff, scmutil, util, diffhelpers, copies, encoding, error
15 import context
15 import context
16
16
17 gitre = re.compile('diff --git a/(.*) b/(.*)')
17 gitre = re.compile('diff --git a/(.*) b/(.*)')
18
18
19 class PatchError(Exception):
19 class PatchError(Exception):
20 pass
20 pass
21
21
22
22
23 # public functions
23 # public functions
24
24
25 def split(stream):
25 def split(stream):
26 '''return an iterator of individual patches from a stream'''
26 '''return an iterator of individual patches from a stream'''
27 def isheader(line, inheader):
27 def isheader(line, inheader):
28 if inheader and line[0] in (' ', '\t'):
28 if inheader and line[0] in (' ', '\t'):
29 # continuation
29 # continuation
30 return True
30 return True
31 if line[0] in (' ', '-', '+'):
31 if line[0] in (' ', '-', '+'):
32 # diff line - don't check for header pattern in there
32 # diff line - don't check for header pattern in there
33 return False
33 return False
34 l = line.split(': ', 1)
34 l = line.split(': ', 1)
35 return len(l) == 2 and ' ' not in l[0]
35 return len(l) == 2 and ' ' not in l[0]
36
36
37 def chunk(lines):
37 def chunk(lines):
38 return cStringIO.StringIO(''.join(lines))
38 return cStringIO.StringIO(''.join(lines))
39
39
40 def hgsplit(stream, cur):
40 def hgsplit(stream, cur):
41 inheader = True
41 inheader = True
42
42
43 for line in stream:
43 for line in stream:
44 if not line.strip():
44 if not line.strip():
45 inheader = False
45 inheader = False
46 if not inheader and line.startswith('# HG changeset patch'):
46 if not inheader and line.startswith('# HG changeset patch'):
47 yield chunk(cur)
47 yield chunk(cur)
48 cur = []
48 cur = []
49 inheader = True
49 inheader = True
50
50
51 cur.append(line)
51 cur.append(line)
52
52
53 if cur:
53 if cur:
54 yield chunk(cur)
54 yield chunk(cur)
55
55
56 def mboxsplit(stream, cur):
56 def mboxsplit(stream, cur):
57 for line in stream:
57 for line in stream:
58 if line.startswith('From '):
58 if line.startswith('From '):
59 for c in split(chunk(cur[1:])):
59 for c in split(chunk(cur[1:])):
60 yield c
60 yield c
61 cur = []
61 cur = []
62
62
63 cur.append(line)
63 cur.append(line)
64
64
65 if cur:
65 if cur:
66 for c in split(chunk(cur[1:])):
66 for c in split(chunk(cur[1:])):
67 yield c
67 yield c
68
68
69 def mimesplit(stream, cur):
69 def mimesplit(stream, cur):
70 def msgfp(m):
70 def msgfp(m):
71 fp = cStringIO.StringIO()
71 fp = cStringIO.StringIO()
72 g = email.Generator.Generator(fp, mangle_from_=False)
72 g = email.Generator.Generator(fp, mangle_from_=False)
73 g.flatten(m)
73 g.flatten(m)
74 fp.seek(0)
74 fp.seek(0)
75 return fp
75 return fp
76
76
77 for line in stream:
77 for line in stream:
78 cur.append(line)
78 cur.append(line)
79 c = chunk(cur)
79 c = chunk(cur)
80
80
81 m = email.Parser.Parser().parse(c)
81 m = email.Parser.Parser().parse(c)
82 if not m.is_multipart():
82 if not m.is_multipart():
83 yield msgfp(m)
83 yield msgfp(m)
84 else:
84 else:
85 ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
85 ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
86 for part in m.walk():
86 for part in m.walk():
87 ct = part.get_content_type()
87 ct = part.get_content_type()
88 if ct not in ok_types:
88 if ct not in ok_types:
89 continue
89 continue
90 yield msgfp(part)
90 yield msgfp(part)
91
91
92 def headersplit(stream, cur):
92 def headersplit(stream, cur):
93 inheader = False
93 inheader = False
94
94
95 for line in stream:
95 for line in stream:
96 if not inheader and isheader(line, inheader):
96 if not inheader and isheader(line, inheader):
97 yield chunk(cur)
97 yield chunk(cur)
98 cur = []
98 cur = []
99 inheader = True
99 inheader = True
100 if inheader and not isheader(line, inheader):
100 if inheader and not isheader(line, inheader):
101 inheader = False
101 inheader = False
102
102
103 cur.append(line)
103 cur.append(line)
104
104
105 if cur:
105 if cur:
106 yield chunk(cur)
106 yield chunk(cur)
107
107
108 def remainder(cur):
108 def remainder(cur):
109 yield chunk(cur)
109 yield chunk(cur)
110
110
111 class fiter(object):
111 class fiter(object):
112 def __init__(self, fp):
112 def __init__(self, fp):
113 self.fp = fp
113 self.fp = fp
114
114
115 def __iter__(self):
115 def __iter__(self):
116 return self
116 return self
117
117
118 def next(self):
118 def next(self):
119 l = self.fp.readline()
119 l = self.fp.readline()
120 if not l:
120 if not l:
121 raise StopIteration
121 raise StopIteration
122 return l
122 return l
123
123
124 inheader = False
124 inheader = False
125 cur = []
125 cur = []
126
126
127 mimeheaders = ['content-type']
127 mimeheaders = ['content-type']
128
128
129 if not hasattr(stream, 'next'):
129 if not hasattr(stream, 'next'):
130 # http responses, for example, have readline but not next
130 # http responses, for example, have readline but not next
131 stream = fiter(stream)
131 stream = fiter(stream)
132
132
133 for line in stream:
133 for line in stream:
134 cur.append(line)
134 cur.append(line)
135 if line.startswith('# HG changeset patch'):
135 if line.startswith('# HG changeset patch'):
136 return hgsplit(stream, cur)
136 return hgsplit(stream, cur)
137 elif line.startswith('From '):
137 elif line.startswith('From '):
138 return mboxsplit(stream, cur)
138 return mboxsplit(stream, cur)
139 elif isheader(line, inheader):
139 elif isheader(line, inheader):
140 inheader = True
140 inheader = True
141 if line.split(':', 1)[0].lower() in mimeheaders:
141 if line.split(':', 1)[0].lower() in mimeheaders:
142 # let email parser handle this
142 # let email parser handle this
143 return mimesplit(stream, cur)
143 return mimesplit(stream, cur)
144 elif line.startswith('--- ') and inheader:
144 elif line.startswith('--- ') and inheader:
145 # No evil headers seen by diff start, split by hand
145 # No evil headers seen by diff start, split by hand
146 return headersplit(stream, cur)
146 return headersplit(stream, cur)
147 # Not enough info, keep reading
147 # Not enough info, keep reading
148
148
149 # if we are here, we have a very plain patch
149 # if we are here, we have a very plain patch
150 return remainder(cur)
150 return remainder(cur)
151
151
152 def extract(ui, fileobj):
152 def extract(ui, fileobj):
153 '''extract patch from data read from fileobj.
153 '''extract patch from data read from fileobj.
154
154
155 patch can be a normal patch or contained in an email message.
155 patch can be a normal patch or contained in an email message.
156
156
157 return tuple (filename, message, user, date, branch, node, p1, p2).
157 return tuple (filename, message, user, date, branch, node, p1, p2).
158 Any item in the returned tuple can be None. If filename is None,
158 Any item in the returned tuple can be None. If filename is None,
159 fileobj did not contain a patch. Caller must unlink filename when done.'''
159 fileobj did not contain a patch. Caller must unlink filename when done.'''
160
160
161 # attempt to detect the start of a patch
161 # attempt to detect the start of a patch
162 # (this heuristic is borrowed from quilt)
162 # (this heuristic is borrowed from quilt)
163 diffre = re.compile(r'^(?:Index:[ \t]|diff[ \t]|RCS file: |'
163 diffre = re.compile(r'^(?:Index:[ \t]|diff[ \t]|RCS file: |'
164 r'retrieving revision [0-9]+(\.[0-9]+)*$|'
164 r'retrieving revision [0-9]+(\.[0-9]+)*$|'
165 r'---[ \t].*?^\+\+\+[ \t]|'
165 r'---[ \t].*?^\+\+\+[ \t]|'
166 r'\*\*\*[ \t].*?^---[ \t])', re.MULTILINE|re.DOTALL)
166 r'\*\*\*[ \t].*?^---[ \t])', re.MULTILINE|re.DOTALL)
167
167
168 fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
168 fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
169 tmpfp = os.fdopen(fd, 'w')
169 tmpfp = os.fdopen(fd, 'w')
170 try:
170 try:
171 msg = email.Parser.Parser().parse(fileobj)
171 msg = email.Parser.Parser().parse(fileobj)
172
172
173 subject = msg['Subject']
173 subject = msg['Subject']
174 user = msg['From']
174 user = msg['From']
175 if not subject and not user:
175 if not subject and not user:
176 # Not an email, restore parsed headers if any
176 # Not an email, restore parsed headers if any
177 subject = '\n'.join(': '.join(h) for h in msg.items()) + '\n'
177 subject = '\n'.join(': '.join(h) for h in msg.items()) + '\n'
178
178
179 gitsendmail = 'git-send-email' in msg.get('X-Mailer', '')
179 gitsendmail = 'git-send-email' in msg.get('X-Mailer', '')
180 # should try to parse msg['Date']
180 # should try to parse msg['Date']
181 date = None
181 date = None
182 nodeid = None
182 nodeid = None
183 branch = None
183 branch = None
184 parents = []
184 parents = []
185
185
186 if subject:
186 if subject:
187 if subject.startswith('[PATCH'):
187 if subject.startswith('[PATCH'):
188 pend = subject.find(']')
188 pend = subject.find(']')
189 if pend >= 0:
189 if pend >= 0:
190 subject = subject[pend + 1:].lstrip()
190 subject = subject[pend + 1:].lstrip()
191 subject = subject.replace('\n\t', ' ')
191 subject = subject.replace('\n\t', ' ')
192 ui.debug('Subject: %s\n' % subject)
192 ui.debug('Subject: %s\n' % subject)
193 if user:
193 if user:
194 ui.debug('From: %s\n' % user)
194 ui.debug('From: %s\n' % user)
195 diffs_seen = 0
195 diffs_seen = 0
196 ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
196 ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
197 message = ''
197 message = ''
198 for part in msg.walk():
198 for part in msg.walk():
199 content_type = part.get_content_type()
199 content_type = part.get_content_type()
200 ui.debug('Content-Type: %s\n' % content_type)
200 ui.debug('Content-Type: %s\n' % content_type)
201 if content_type not in ok_types:
201 if content_type not in ok_types:
202 continue
202 continue
203 payload = part.get_payload(decode=True)
203 payload = part.get_payload(decode=True)
204 m = diffre.search(payload)
204 m = diffre.search(payload)
205 if m:
205 if m:
206 hgpatch = False
206 hgpatch = False
207 hgpatchheader = False
207 hgpatchheader = False
208 ignoretext = False
208 ignoretext = False
209
209
210 ui.debug('found patch at byte %d\n' % m.start(0))
210 ui.debug('found patch at byte %d\n' % m.start(0))
211 diffs_seen += 1
211 diffs_seen += 1
212 cfp = cStringIO.StringIO()
212 cfp = cStringIO.StringIO()
213 for line in payload[:m.start(0)].splitlines():
213 for line in payload[:m.start(0)].splitlines():
214 if line.startswith('# HG changeset patch') and not hgpatch:
214 if line.startswith('# HG changeset patch') and not hgpatch:
215 ui.debug('patch generated by hg export\n')
215 ui.debug('patch generated by hg export\n')
216 hgpatch = True
216 hgpatch = True
217 hgpatchheader = True
217 hgpatchheader = True
218 # drop earlier commit message content
218 # drop earlier commit message content
219 cfp.seek(0)
219 cfp.seek(0)
220 cfp.truncate()
220 cfp.truncate()
221 subject = None
221 subject = None
222 elif hgpatchheader:
222 elif hgpatchheader:
223 if line.startswith('# User '):
223 if line.startswith('# User '):
224 user = line[7:]
224 user = line[7:]
225 ui.debug('From: %s\n' % user)
225 ui.debug('From: %s\n' % user)
226 elif line.startswith("# Date "):
226 elif line.startswith("# Date "):
227 date = line[7:]
227 date = line[7:]
228 elif line.startswith("# Branch "):
228 elif line.startswith("# Branch "):
229 branch = line[9:]
229 branch = line[9:]
230 elif line.startswith("# Node ID "):
230 elif line.startswith("# Node ID "):
231 nodeid = line[10:]
231 nodeid = line[10:]
232 elif line.startswith("# Parent "):
232 elif line.startswith("# Parent "):
233 parents.append(line[10:])
233 parents.append(line[10:])
234 elif not line.startswith("# "):
234 elif not line.startswith("# "):
235 hgpatchheader = False
235 hgpatchheader = False
236 elif line == '---' and gitsendmail:
236 elif line == '---' and gitsendmail:
237 ignoretext = True
237 ignoretext = True
238 if not hgpatchheader and not ignoretext:
238 if not hgpatchheader and not ignoretext:
239 cfp.write(line)
239 cfp.write(line)
240 cfp.write('\n')
240 cfp.write('\n')
241 message = cfp.getvalue()
241 message = cfp.getvalue()
242 if tmpfp:
242 if tmpfp:
243 tmpfp.write(payload)
243 tmpfp.write(payload)
244 if not payload.endswith('\n'):
244 if not payload.endswith('\n'):
245 tmpfp.write('\n')
245 tmpfp.write('\n')
246 elif not diffs_seen and message and content_type == 'text/plain':
246 elif not diffs_seen and message and content_type == 'text/plain':
247 message += '\n' + payload
247 message += '\n' + payload
248 except:
248 except:
249 tmpfp.close()
249 tmpfp.close()
250 os.unlink(tmpname)
250 os.unlink(tmpname)
251 raise
251 raise
252
252
253 if subject and not message.startswith(subject):
253 if subject and not message.startswith(subject):
254 message = '%s\n%s' % (subject, message)
254 message = '%s\n%s' % (subject, message)
255 tmpfp.close()
255 tmpfp.close()
256 if not diffs_seen:
256 if not diffs_seen:
257 os.unlink(tmpname)
257 os.unlink(tmpname)
258 return None, message, user, date, branch, None, None, None
258 return None, message, user, date, branch, None, None, None
259 p1 = parents and parents.pop(0) or None
259 p1 = parents and parents.pop(0) or None
260 p2 = parents and parents.pop(0) or None
260 p2 = parents and parents.pop(0) or None
261 return tmpname, message, user, date, branch, nodeid, p1, p2
261 return tmpname, message, user, date, branch, nodeid, p1, p2
262
262
263 class patchmeta(object):
263 class patchmeta(object):
264 """Patched file metadata
264 """Patched file metadata
265
265
266 'op' is the performed operation within ADD, DELETE, RENAME, MODIFY
266 'op' is the performed operation within ADD, DELETE, RENAME, MODIFY
267 or COPY. 'path' is patched file path. 'oldpath' is set to the
267 or COPY. 'path' is patched file path. 'oldpath' is set to the
268 origin file when 'op' is either COPY or RENAME, None otherwise. If
268 origin file when 'op' is either COPY or RENAME, None otherwise. If
269 file mode is changed, 'mode' is a tuple (islink, isexec) where
269 file mode is changed, 'mode' is a tuple (islink, isexec) where
270 'islink' is True if the file is a symlink and 'isexec' is True if
270 'islink' is True if the file is a symlink and 'isexec' is True if
271 the file is executable. Otherwise, 'mode' is None.
271 the file is executable. Otherwise, 'mode' is None.
272 """
272 """
273 def __init__(self, path):
273 def __init__(self, path):
274 self.path = path
274 self.path = path
275 self.oldpath = None
275 self.oldpath = None
276 self.mode = None
276 self.mode = None
277 self.op = 'MODIFY'
277 self.op = 'MODIFY'
278 self.binary = False
278 self.binary = False
279
279
280 def setmode(self, mode):
280 def setmode(self, mode):
281 islink = mode & 020000
281 islink = mode & 020000
282 isexec = mode & 0100
282 isexec = mode & 0100
283 self.mode = (islink, isexec)
283 self.mode = (islink, isexec)
284
284
285 def copy(self):
285 def copy(self):
286 other = patchmeta(self.path)
286 other = patchmeta(self.path)
287 other.oldpath = self.oldpath
287 other.oldpath = self.oldpath
288 other.mode = self.mode
288 other.mode = self.mode
289 other.op = self.op
289 other.op = self.op
290 other.binary = self.binary
290 other.binary = self.binary
291 return other
291 return other
292
292
293 def __repr__(self):
293 def __repr__(self):
294 return "<patchmeta %s %r>" % (self.op, self.path)
294 return "<patchmeta %s %r>" % (self.op, self.path)
295
295
296 def readgitpatch(lr):
296 def readgitpatch(lr):
297 """extract git-style metadata about patches from <patchname>"""
297 """extract git-style metadata about patches from <patchname>"""
298
298
299 # Filter patch for git information
299 # Filter patch for git information
300 gp = None
300 gp = None
301 gitpatches = []
301 gitpatches = []
302 for line in lr:
302 for line in lr:
303 line = line.rstrip(' \r\n')
303 line = line.rstrip(' \r\n')
304 if line.startswith('diff --git'):
304 if line.startswith('diff --git'):
305 m = gitre.match(line)
305 m = gitre.match(line)
306 if m:
306 if m:
307 if gp:
307 if gp:
308 gitpatches.append(gp)
308 gitpatches.append(gp)
309 dst = m.group(2)
309 dst = m.group(2)
310 gp = patchmeta(dst)
310 gp = patchmeta(dst)
311 elif gp:
311 elif gp:
312 if line.startswith('--- '):
312 if line.startswith('--- '):
313 gitpatches.append(gp)
313 gitpatches.append(gp)
314 gp = None
314 gp = None
315 continue
315 continue
316 if line.startswith('rename from '):
316 if line.startswith('rename from '):
317 gp.op = 'RENAME'
317 gp.op = 'RENAME'
318 gp.oldpath = line[12:]
318 gp.oldpath = line[12:]
319 elif line.startswith('rename to '):
319 elif line.startswith('rename to '):
320 gp.path = line[10:]
320 gp.path = line[10:]
321 elif line.startswith('copy from '):
321 elif line.startswith('copy from '):
322 gp.op = 'COPY'
322 gp.op = 'COPY'
323 gp.oldpath = line[10:]
323 gp.oldpath = line[10:]
324 elif line.startswith('copy to '):
324 elif line.startswith('copy to '):
325 gp.path = line[8:]
325 gp.path = line[8:]
326 elif line.startswith('deleted file'):
326 elif line.startswith('deleted file'):
327 gp.op = 'DELETE'
327 gp.op = 'DELETE'
328 elif line.startswith('new file mode '):
328 elif line.startswith('new file mode '):
329 gp.op = 'ADD'
329 gp.op = 'ADD'
330 gp.setmode(int(line[-6:], 8))
330 gp.setmode(int(line[-6:], 8))
331 elif line.startswith('new mode '):
331 elif line.startswith('new mode '):
332 gp.setmode(int(line[-6:], 8))
332 gp.setmode(int(line[-6:], 8))
333 elif line.startswith('GIT binary patch'):
333 elif line.startswith('GIT binary patch'):
334 gp.binary = True
334 gp.binary = True
335 if gp:
335 if gp:
336 gitpatches.append(gp)
336 gitpatches.append(gp)
337
337
338 return gitpatches
338 return gitpatches
339
339
340 class linereader(object):
340 class linereader(object):
341 # simple class to allow pushing lines back into the input stream
341 # simple class to allow pushing lines back into the input stream
342 def __init__(self, fp):
342 def __init__(self, fp):
343 self.fp = fp
343 self.fp = fp
344 self.buf = []
344 self.buf = []
345
345
346 def push(self, line):
346 def push(self, line):
347 if line is not None:
347 if line is not None:
348 self.buf.append(line)
348 self.buf.append(line)
349
349
350 def readline(self):
350 def readline(self):
351 if self.buf:
351 if self.buf:
352 l = self.buf[0]
352 l = self.buf[0]
353 del self.buf[0]
353 del self.buf[0]
354 return l
354 return l
355 return self.fp.readline()
355 return self.fp.readline()
356
356
357 def __iter__(self):
357 def __iter__(self):
358 while True:
358 while True:
359 l = self.readline()
359 l = self.readline()
360 if not l:
360 if not l:
361 break
361 break
362 yield l
362 yield l
363
363
364 class abstractbackend(object):
364 class abstractbackend(object):
365 def __init__(self, ui):
365 def __init__(self, ui):
366 self.ui = ui
366 self.ui = ui
367
367
368 def getfile(self, fname):
368 def getfile(self, fname):
369 """Return target file data and flags as a (data, (islink,
369 """Return target file data and flags as a (data, (islink,
370 isexec)) tuple.
370 isexec)) tuple.
371 """
371 """
372 raise NotImplementedError
372 raise NotImplementedError
373
373
374 def setfile(self, fname, data, mode, copysource):
374 def setfile(self, fname, data, mode, copysource):
375 """Write data to target file fname and set its mode. mode is a
375 """Write data to target file fname and set its mode. mode is a
376 (islink, isexec) tuple. If data is None, the file content should
376 (islink, isexec) tuple. If data is None, the file content should
377 be left unchanged. If the file is modified after being copied,
377 be left unchanged. If the file is modified after being copied,
378 copysource is set to the original file name.
378 copysource is set to the original file name.
379 """
379 """
380 raise NotImplementedError
380 raise NotImplementedError
381
381
382 def unlink(self, fname):
382 def unlink(self, fname):
383 """Unlink target file."""
383 """Unlink target file."""
384 raise NotImplementedError
384 raise NotImplementedError
385
385
386 def writerej(self, fname, failed, total, lines):
386 def writerej(self, fname, failed, total, lines):
387 """Write rejected lines for fname. total is the number of hunks
387 """Write rejected lines for fname. total is the number of hunks
388 which failed to apply and total the total number of hunks for this
388 which failed to apply and total the total number of hunks for this
389 files.
389 files.
390 """
390 """
391 pass
391 pass
392
392
393 def exists(self, fname):
393 def exists(self, fname):
394 raise NotImplementedError
394 raise NotImplementedError
395
395
396 class fsbackend(abstractbackend):
396 class fsbackend(abstractbackend):
397 def __init__(self, ui, basedir):
397 def __init__(self, ui, basedir):
398 super(fsbackend, self).__init__(ui)
398 super(fsbackend, self).__init__(ui)
399 self.opener = scmutil.opener(basedir)
399 self.opener = scmutil.opener(basedir)
400
400
401 def _join(self, f):
401 def _join(self, f):
402 return os.path.join(self.opener.base, f)
402 return os.path.join(self.opener.base, f)
403
403
404 def getfile(self, fname):
404 def getfile(self, fname):
405 path = self._join(fname)
405 path = self._join(fname)
406 if os.path.islink(path):
406 if os.path.islink(path):
407 return (os.readlink(path), (True, False))
407 return (os.readlink(path), (True, False))
408 isexec = False
408 isexec = False
409 try:
409 try:
410 isexec = os.lstat(path).st_mode & 0100 != 0
410 isexec = os.lstat(path).st_mode & 0100 != 0
411 except OSError, e:
411 except OSError, e:
412 if e.errno != errno.ENOENT:
412 if e.errno != errno.ENOENT:
413 raise
413 raise
414 return (self.opener.read(fname), (False, isexec))
414 return (self.opener.read(fname), (False, isexec))
415
415
416 def setfile(self, fname, data, mode, copysource):
416 def setfile(self, fname, data, mode, copysource):
417 islink, isexec = mode
417 islink, isexec = mode
418 if data is None:
418 if data is None:
419 util.setflags(self._join(fname), islink, isexec)
419 util.setflags(self._join(fname), islink, isexec)
420 return
420 return
421 if islink:
421 if islink:
422 self.opener.symlink(data, fname)
422 self.opener.symlink(data, fname)
423 else:
423 else:
424 self.opener.write(fname, data)
424 self.opener.write(fname, data)
425 if isexec:
425 if isexec:
426 util.setflags(self._join(fname), False, True)
426 util.setflags(self._join(fname), False, True)
427
427
428 def unlink(self, fname):
428 def unlink(self, fname):
429 try:
429 try:
430 util.unlinkpath(self._join(fname))
430 util.unlinkpath(self._join(fname))
431 except OSError, inst:
431 except OSError, inst:
432 if inst.errno != errno.ENOENT:
432 if inst.errno != errno.ENOENT:
433 raise
433 raise
434
434
435 def writerej(self, fname, failed, total, lines):
435 def writerej(self, fname, failed, total, lines):
436 fname = fname + ".rej"
436 fname = fname + ".rej"
437 self.ui.warn(
437 self.ui.warn(
438 _("%d out of %d hunks FAILED -- saving rejects to file %s\n") %
438 _("%d out of %d hunks FAILED -- saving rejects to file %s\n") %
439 (failed, total, fname))
439 (failed, total, fname))
440 fp = self.opener(fname, 'w')
440 fp = self.opener(fname, 'w')
441 fp.writelines(lines)
441 fp.writelines(lines)
442 fp.close()
442 fp.close()
443
443
444 def exists(self, fname):
444 def exists(self, fname):
445 return os.path.lexists(self._join(fname))
445 return os.path.lexists(self._join(fname))
446
446
447 class workingbackend(fsbackend):
447 class workingbackend(fsbackend):
448 def __init__(self, ui, repo, similarity):
448 def __init__(self, ui, repo, similarity):
449 super(workingbackend, self).__init__(ui, repo.root)
449 super(workingbackend, self).__init__(ui, repo.root)
450 self.repo = repo
450 self.repo = repo
451 self.similarity = similarity
451 self.similarity = similarity
452 self.removed = set()
452 self.removed = set()
453 self.changed = set()
453 self.changed = set()
454 self.copied = []
454 self.copied = []
455
455
456 def _checkknown(self, fname):
456 def _checkknown(self, fname):
457 if self.repo.dirstate[fname] == '?' and self.exists(fname):
457 if self.repo.dirstate[fname] == '?' and self.exists(fname):
458 raise PatchError(_('cannot patch %s: file is not tracked') % fname)
458 raise PatchError(_('cannot patch %s: file is not tracked') % fname)
459
459
460 def setfile(self, fname, data, mode, copysource):
460 def setfile(self, fname, data, mode, copysource):
461 self._checkknown(fname)
461 self._checkknown(fname)
462 super(workingbackend, self).setfile(fname, data, mode, copysource)
462 super(workingbackend, self).setfile(fname, data, mode, copysource)
463 if copysource is not None:
463 if copysource is not None:
464 self.copied.append((copysource, fname))
464 self.copied.append((copysource, fname))
465 self.changed.add(fname)
465 self.changed.add(fname)
466
466
467 def unlink(self, fname):
467 def unlink(self, fname):
468 self._checkknown(fname)
468 self._checkknown(fname)
469 super(workingbackend, self).unlink(fname)
469 super(workingbackend, self).unlink(fname)
470 self.removed.add(fname)
470 self.removed.add(fname)
471 self.changed.add(fname)
471 self.changed.add(fname)
472
472
473 def close(self):
473 def close(self):
474 wctx = self.repo[None]
474 wctx = self.repo[None]
475 addremoved = set(self.changed)
475 addremoved = set(self.changed)
476 for src, dst in self.copied:
476 for src, dst in self.copied:
477 scmutil.dirstatecopy(self.ui, self.repo, wctx, src, dst)
477 scmutil.dirstatecopy(self.ui, self.repo, wctx, src, dst)
478 addremoved.discard(src)
478 addremoved.discard(src)
479 if (not self.similarity) and self.removed:
479 if (not self.similarity) and self.removed:
480 wctx.forget(sorted(self.removed))
480 wctx.forget(sorted(self.removed))
481 if addremoved:
481 if addremoved:
482 cwd = self.repo.getcwd()
482 cwd = self.repo.getcwd()
483 if cwd:
483 if cwd:
484 addremoved = [util.pathto(self.repo.root, cwd, f)
484 addremoved = [util.pathto(self.repo.root, cwd, f)
485 for f in addremoved]
485 for f in addremoved]
486 scmutil.addremove(self.repo, addremoved, similarity=self.similarity)
486 scmutil.addremove(self.repo, addremoved, similarity=self.similarity)
487 return sorted(self.changed)
487 return sorted(self.changed)
488
488
489 class filestore(object):
489 class filestore(object):
490 def __init__(self, maxsize=None):
490 def __init__(self, maxsize=None):
491 self.opener = None
491 self.opener = None
492 self.files = {}
492 self.files = {}
493 self.created = 0
493 self.created = 0
494 self.maxsize = maxsize
494 self.maxsize = maxsize
495 if self.maxsize is None:
495 if self.maxsize is None:
496 self.maxsize = 4*(2**20)
496 self.maxsize = 4*(2**20)
497 self.size = 0
497 self.size = 0
498 self.data = {}
498 self.data = {}
499
499
500 def setfile(self, fname, data, mode, copied=None):
500 def setfile(self, fname, data, mode, copied=None):
501 if self.maxsize < 0 or (len(data) + self.size) <= self.maxsize:
501 if self.maxsize < 0 or (len(data) + self.size) <= self.maxsize:
502 self.data[fname] = (data, mode, copied)
502 self.data[fname] = (data, mode, copied)
503 self.size += len(data)
503 self.size += len(data)
504 else:
504 else:
505 if self.opener is None:
505 if self.opener is None:
506 root = tempfile.mkdtemp(prefix='hg-patch-')
506 root = tempfile.mkdtemp(prefix='hg-patch-')
507 self.opener = scmutil.opener(root)
507 self.opener = scmutil.opener(root)
508 # Avoid filename issues with these simple names
508 # Avoid filename issues with these simple names
509 fn = str(self.created)
509 fn = str(self.created)
510 self.opener.write(fn, data)
510 self.opener.write(fn, data)
511 self.created += 1
511 self.created += 1
512 self.files[fname] = (fn, mode, copied)
512 self.files[fname] = (fn, mode, copied)
513
513
514 def getfile(self, fname):
514 def getfile(self, fname):
515 if fname in self.data:
515 if fname in self.data:
516 return self.data[fname]
516 return self.data[fname]
517 if not self.opener or fname not in self.files:
517 if not self.opener or fname not in self.files:
518 raise IOError()
518 raise IOError()
519 fn, mode, copied = self.files[fname]
519 fn, mode, copied = self.files[fname]
520 return self.opener.read(fn), mode, copied
520 return self.opener.read(fn), mode, copied
521
521
522 def close(self):
522 def close(self):
523 if self.opener:
523 if self.opener:
524 shutil.rmtree(self.opener.base)
524 shutil.rmtree(self.opener.base)
525
525
526 class repobackend(abstractbackend):
526 class repobackend(abstractbackend):
527 def __init__(self, ui, repo, ctx, store):
527 def __init__(self, ui, repo, ctx, store):
528 super(repobackend, self).__init__(ui)
528 super(repobackend, self).__init__(ui)
529 self.repo = repo
529 self.repo = repo
530 self.ctx = ctx
530 self.ctx = ctx
531 self.store = store
531 self.store = store
532 self.changed = set()
532 self.changed = set()
533 self.removed = set()
533 self.removed = set()
534 self.copied = {}
534 self.copied = {}
535
535
536 def _checkknown(self, fname):
536 def _checkknown(self, fname):
537 if fname not in self.ctx:
537 if fname not in self.ctx:
538 raise PatchError(_('cannot patch %s: file is not tracked') % fname)
538 raise PatchError(_('cannot patch %s: file is not tracked') % fname)
539
539
540 def getfile(self, fname):
540 def getfile(self, fname):
541 try:
541 try:
542 fctx = self.ctx[fname]
542 fctx = self.ctx[fname]
543 except error.LookupError:
543 except error.LookupError:
544 raise IOError()
544 raise IOError()
545 flags = fctx.flags()
545 flags = fctx.flags()
546 return fctx.data(), ('l' in flags, 'x' in flags)
546 return fctx.data(), ('l' in flags, 'x' in flags)
547
547
548 def setfile(self, fname, data, mode, copysource):
548 def setfile(self, fname, data, mode, copysource):
549 if copysource:
549 if copysource:
550 self._checkknown(copysource)
550 self._checkknown(copysource)
551 if data is None:
551 if data is None:
552 data = self.ctx[fname].data()
552 data = self.ctx[fname].data()
553 self.store.setfile(fname, data, mode, copysource)
553 self.store.setfile(fname, data, mode, copysource)
554 self.changed.add(fname)
554 self.changed.add(fname)
555 if copysource:
555 if copysource:
556 self.copied[fname] = copysource
556 self.copied[fname] = copysource
557
557
558 def unlink(self, fname):
558 def unlink(self, fname):
559 self._checkknown(fname)
559 self._checkknown(fname)
560 self.removed.add(fname)
560 self.removed.add(fname)
561
561
562 def exists(self, fname):
562 def exists(self, fname):
563 return fname in self.ctx
563 return fname in self.ctx
564
564
565 def close(self):
565 def close(self):
566 return self.changed | self.removed
566 return self.changed | self.removed
567
567
568 # @@ -start,len +start,len @@ or @@ -start +start @@ if len is 1
568 # @@ -start,len +start,len @@ or @@ -start +start @@ if len is 1
569 unidesc = re.compile('@@ -(\d+)(,(\d+))? \+(\d+)(,(\d+))? @@')
569 unidesc = re.compile('@@ -(\d+)(,(\d+))? \+(\d+)(,(\d+))? @@')
570 contextdesc = re.compile('(---|\*\*\*) (\d+)(,(\d+))? (---|\*\*\*)')
570 contextdesc = re.compile('(---|\*\*\*) (\d+)(,(\d+))? (---|\*\*\*)')
571 eolmodes = ['strict', 'crlf', 'lf', 'auto']
571 eolmodes = ['strict', 'crlf', 'lf', 'auto']
572
572
class patchfile(object):
    """One target file being patched.

    Loads the current file content through *backend* (or *store* when the
    hunk is a copy), applies hunks one at a time via apply(), collects
    rejected hunks, and flushes the result (plus any .rej data) in close().
    """
    def __init__(self, ui, gp, backend, store, eolmode='strict'):
        # gp is a patchmeta describing the file operation (path, mode,
        # oldpath for copies/renames, and the ADD/COPY/RENAME/DELETE op)
        self.fname = gp.path
        self.eolmode = eolmode
        self.eol = None
        self.backend = backend
        self.ui = ui
        self.lines = []
        self.exists = False
        self.missing = True
        self.mode = gp.mode
        self.copysource = gp.oldpath
        self.create = gp.op in ('ADD', 'COPY', 'RENAME')
        self.remove = gp.op == 'DELETE'
        try:
            if self.copysource is None:
                # plain modification: read the target itself
                data, mode = backend.getfile(self.fname)
                self.exists = True
            else:
                # copy/rename: content comes from the store, but the
                # destination may or may not already exist
                data, mode = store.getfile(self.copysource)[:2]
                self.exists = backend.exists(self.fname)
            self.missing = False
            if data:
                self.lines = data.splitlines(True)
            if self.mode is None:
                self.mode = mode
            if self.lines:
                # Normalize line endings
                if self.lines[0].endswith('\r\n'):
                    self.eol = '\r\n'
                elif self.lines[0].endswith('\n'):
                    self.eol = '\n'
                if eolmode != 'strict':
                    # work internally on LF; writelines() restores eol
                    nlines = []
                    for l in self.lines:
                        if l.endswith('\r\n'):
                            l = l[:-2] + '\n'
                        nlines.append(l)
                    self.lines = nlines
        except IOError:
            # backend.getfile raises IOError for a missing file; that is
            # fine when the patch is creating it
            if self.create:
                self.missing = False
            if self.mode is None:
                self.mode = (False, False)
        if self.missing:
            self.ui.warn(_("unable to find '%s' for patching\n") % self.fname)

        # per-apply() state: line hash for fuzzy matching, line offset
        # accumulated by earlier hunks, and skew from fuzzy placement
        self.hash = {}
        self.dirty = 0
        self.offset = 0
        self.skew = 0
        self.rej = []
        self.fileprinted = False
        self.printfile(False)
        self.hunks = 0

    def writelines(self, fname, lines, mode):
        """Write *lines* through the backend, restoring the requested eol."""
        if self.eolmode == 'auto':
            eol = self.eol
        elif self.eolmode == 'crlf':
            eol = '\r\n'
        else:
            eol = '\n'

        if self.eolmode != 'strict' and eol and eol != '\n':
            # convert the LF-normalized working lines back to target eol
            rawlines = []
            for l in lines:
                if l and l[-1] == '\n':
                    l = l[:-1] + eol
                rawlines.append(l)
            lines = rawlines

        self.backend.setfile(fname, ''.join(lines), mode, self.copysource)

    def printfile(self, warn):
        """Emit "patching file X" once, as a warning or a note."""
        if self.fileprinted:
            return
        if warn or self.ui.verbose:
            self.fileprinted = True
        s = _("patching file %s\n") % self.fname
        if warn:
            self.ui.warn(s)
        else:
            self.ui.note(s)


    def findlines(self, l, linenum):
        # looks through the hash and finds candidate lines. The
        # result is a list of line numbers sorted based on distance
        # from linenum
        cand = self.hash.get(l, [])
        if len(cand) > 1:
            # resort our list of potentials forward then back.
            cand.sort(key=lambda x: abs(x - linenum))
        return cand

    def write_rej(self):
        # our rejects are a little different from patch(1). This always
        # creates rejects in the same form as the original patch. A file
        # header is inserted so that you can run the reject through patch again
        # without having to type the filename.
        if not self.rej:
            return
        base = os.path.basename(self.fname)
        lines = ["--- %s\n+++ %s\n" % (base, base)]
        for x in self.rej:
            for l in x.hunk:
                lines.append(l)
                if l[-1] != '\n':
                    lines.append("\n\ No newline at end of file\n")
        self.backend.writerej(self.fname, len(self.rej), self.hunks, lines)

    def apply(self, h):
        """Apply one hunk (text or binary) to the in-memory lines.

        Returns 0 on a clean apply, the fuzz level (>0) on a fuzzy apply,
        and -1 when the hunk is rejected (recorded in self.rej).
        """
        if not h.complete():
            raise PatchError(_("bad hunk #%d %s (%d %d %d %d)") %
                             (h.number, h.desc, len(h.a), h.lena, len(h.b),
                              h.lenb))

        self.hunks += 1

        if self.missing:
            self.rej.append(h)
            return -1

        if self.exists and self.create:
            if self.copysource:
                # NOTE(review): the '%' here is applied inside _(), so the
                # message is formatted before translation lookup — confirm
                # whether this is intended.
                self.ui.warn(_("cannot create %s: destination already "
                               "exists\n" % self.fname))
            else:
                self.ui.warn(_("file %s already exists\n") % self.fname)
            self.rej.append(h)
            return -1

        if isinstance(h, binhunk):
            # binary hunks replace the whole file (or delete it)
            if self.remove:
                self.backend.unlink(self.fname)
            else:
                self.lines[:] = h.new()
                self.offset += len(h.new())
                self.dirty = True
            return 0

        horig = h
        if (self.eolmode in ('crlf', 'lf')
            or self.eolmode == 'auto' and self.eol):
            # If new eols are going to be normalized, then normalize
            # hunk data before patching. Otherwise, preserve input
            # line-endings.
            h = h.getnormalized()

        # fast case first, no offsets, no fuzz
        old = h.old()
        # patch starts counting at 1 unless we are adding the file
        if h.starta == 0:
            start = 0
        else:
            start = h.starta + self.offset - 1
        orig_start = start
        # if there's skew we want to emit the "(offset %d lines)" even
        # when the hunk cleanly applies at start + skew, so skip the
        # fast case code
        if self.skew == 0 and diffhelpers.testhunk(old, self.lines, start) == 0:
            if self.remove:
                self.backend.unlink(self.fname)
            else:
                self.lines[start : start + h.lena] = h.new()
                self.offset += h.lenb - h.lena
                self.dirty = True
            return 0

        # ok, we couldn't match the hunk. Lets look for offsets and fuzz it
        self.hash = {}
        for x, s in enumerate(self.lines):
            self.hash.setdefault(s, []).append(x)
        if h.hunk[-1][0] != ' ':
            # if the hunk tried to put something at the bottom of the file
            # override the start line and use eof here
            search_start = len(self.lines)
        else:
            search_start = orig_start + self.skew

        for fuzzlen in xrange(3):
            for toponly in [True, False]:
                old = h.old(fuzzlen, toponly)

                # candidate positions: lines matching the hunk's first
                # old line, sorted by distance from search_start
                cand = self.findlines(old[0][1:], search_start)
                for l in cand:
                    if diffhelpers.testhunk(old, self.lines, l) == 0:
                        newlines = h.new(fuzzlen, toponly)
                        self.lines[l : l + len(old)] = newlines
                        self.offset += len(newlines) - len(old)
                        self.skew = l - orig_start
                        self.dirty = True
                        offset = l - orig_start - fuzzlen
                        if fuzzlen:
                            msg = _("Hunk #%d succeeded at %d "
                                    "with fuzz %d "
                                    "(offset %d lines).\n")
                            self.printfile(True)
                            self.ui.warn(msg %
                                (h.number, l + 1, fuzzlen, offset))
                        else:
                            msg = _("Hunk #%d succeeded at %d "
                                    "(offset %d lines).\n")
                            self.ui.note(msg % (h.number, l + 1, offset))
                        return fuzzlen
        self.printfile(True)
        self.ui.warn(_("Hunk #%d FAILED at %d\n") % (h.number, orig_start))
        self.rej.append(horig)
        return -1

    def close(self):
        """Flush modified lines and rejects; return the reject count."""
        if self.dirty:
            self.writelines(self.fname, self.lines, self.mode)
        self.write_rej()
        return len(self.rej)
790
790
class hunk(object):
    """One textual hunk, parsed from unified or context diff format.

    self.a holds the old-side lines (prefixed '-' or ' '), self.b the
    new-side lines (no prefix), and self.hunk the raw hunk lines headed
    by the '@@ ...' description.
    """
    def __init__(self, desc, num, lr, context):
        self.number = num
        self.desc = desc
        self.hunk = [desc]
        self.a = []
        self.b = []
        self.starta = self.lena = None
        self.startb = self.lenb = None
        # lr is None when building a shell object (see getnormalized)
        if lr is not None:
            if context:
                self.read_context_hunk(lr)
            else:
                self.read_unified_hunk(lr)

    def getnormalized(self):
        """Return a copy with line endings normalized to LF."""

        def normalize(lines):
            nlines = []
            for line in lines:
                if line.endswith('\r\n'):
                    line = line[:-2] + '\n'
                nlines.append(line)
            return nlines

        # Dummy object, it is rebuilt manually
        nh = hunk(self.desc, self.number, None, None)
        nh.number = self.number
        nh.desc = self.desc
        nh.hunk = self.hunk
        nh.a = normalize(self.a)
        nh.b = normalize(self.b)
        nh.starta = self.starta
        nh.startb = self.startb
        nh.lena = self.lena
        nh.lenb = self.lenb
        return nh

    def read_unified_hunk(self, lr):
        """Parse a unified-format hunk body from linereader *lr*."""
        m = unidesc.match(self.desc)
        if not m:
            raise PatchError(_("bad hunk #%d") % self.number)
        # groups are (starta, ',lena', lena, startb, ',lenb', lenb);
        # a missing length means 1 per the unified diff format
        self.starta, foo, self.lena, self.startb, foo2, self.lenb = m.groups()
        if self.lena is None:
            self.lena = 1
        else:
            self.lena = int(self.lena)
        if self.lenb is None:
            self.lenb = 1
        else:
            self.lenb = int(self.lenb)
        self.starta = int(self.starta)
        self.startb = int(self.startb)
        diffhelpers.addlines(lr, self.hunk, self.lena, self.lenb, self.a, self.b)
        # if we hit eof before finishing out the hunk, the last line will
        # be zero length. Lets try to fix it up.
        while len(self.hunk[-1]) == 0:
            del self.hunk[-1]
            del self.a[-1]
            del self.b[-1]
            self.lena -= 1
            self.lenb -= 1
        self._fixnewline(lr)

    def read_context_hunk(self, lr):
        """Parse a context-format hunk and rebuild it in unified form."""
        self.desc = lr.readline()
        m = contextdesc.match(self.desc)
        if not m:
            raise PatchError(_("bad hunk #%d") % self.number)
        foo, self.starta, foo2, aend, foo3 = m.groups()
        self.starta = int(self.starta)
        if aend is None:
            aend = self.starta
        self.lena = int(aend) - self.starta
        if self.starta:
            self.lena += 1
        # read the old block: '- ' deletions, '! ' changes, '  ' context
        for x in xrange(self.lena):
            l = lr.readline()
            if l.startswith('---'):
                # lines addition, old block is empty
                lr.push(l)
                break
            s = l[2:]
            if l.startswith('- ') or l.startswith('! '):
                u = '-' + s
            elif l.startswith('  '):
                u = ' ' + s
            else:
                raise PatchError(_("bad hunk #%d old text line %d") %
                                 (self.number, x))
            self.a.append(u)
            self.hunk.append(u)

        l = lr.readline()
        if l.startswith('\ '):
            # '\ No newline at end of file': strip the trailing newline
            s = self.a[-1][:-1]
            self.a[-1] = s
            self.hunk[-1] = s
            l = lr.readline()
        m = contextdesc.match(l)
        if not m:
            raise PatchError(_("bad hunk #%d") % self.number)
        foo, self.startb, foo2, bend, foo3 = m.groups()
        self.startb = int(self.startb)
        if bend is None:
            bend = self.startb
        self.lenb = int(bend) - self.startb
        if self.startb:
            self.lenb += 1
        hunki = 1
        # read the new block, interleaving '+' lines into self.hunk at
        # the right positions relative to the old-side lines
        for x in xrange(self.lenb):
            l = lr.readline()
            if l.startswith('\ '):
                # XXX: the only way to hit this is with an invalid line range.
                # The no-eol marker is not counted in the line range, but I
                # guess there are diff(1) out there which behave differently.
                s = self.b[-1][:-1]
                self.b[-1] = s
                self.hunk[hunki - 1] = s
                continue
            if not l:
                # line deletions, new block is empty and we hit EOF
                lr.push(l)
                break
            s = l[2:]
            if l.startswith('+ ') or l.startswith('! '):
                u = '+' + s
            elif l.startswith('  '):
                u = ' ' + s
            elif len(self.b) == 0:
                # line deletions, new block is empty
                lr.push(l)
                break
            else:
                raise PatchError(_("bad hunk #%d old text line %d") %
                                 (self.number, x))
            self.b.append(s)
            # advance through self.hunk to the matching position,
            # skipping '-' lines, and insert the new line if absent
            while True:
                if hunki >= len(self.hunk):
                    h = ""
                else:
                    h = self.hunk[hunki]
                hunki += 1
                if h == u:
                    break
                elif h.startswith('-'):
                    continue
                else:
                    self.hunk.insert(hunki - 1, u)
                    break

        if not self.a:
            # this happens when lines were only added to the hunk
            for x in self.hunk:
                if x.startswith('-') or x.startswith(' '):
                    self.a.append(x)
        if not self.b:
            # this happens when lines were only deleted from the hunk
            for x in self.hunk:
                if x.startswith('+') or x.startswith(' '):
                    self.b.append(x[1:])
        # @@ -start,len +start,len @@
        self.desc = "@@ -%d,%d +%d,%d @@\n" % (self.starta, self.lena,
                                               self.startb, self.lenb)
        self.hunk[0] = self.desc
        self._fixnewline(lr)

    def _fixnewline(self, lr):
        # consume a trailing '\ No newline at end of file' marker, if any
        l = lr.readline()
        if l.startswith('\ '):
            diffhelpers.fix_newline(self.hunk, self.a, self.b)
        else:
            lr.push(l)

    def complete(self):
        """True when the parsed line counts match the declared lengths."""
        return len(self.a) == self.lena and len(self.b) == self.lenb

    def fuzzit(self, l, fuzz, toponly):
        # this removes context lines from the top and bottom of list 'l'. It
        # checks the hunk to make sure only context lines are removed, and then
        # returns a new shortened list of lines.
        fuzz = min(fuzz, len(l)-1)
        if fuzz:
            top = 0
            bot = 0
            hlen = len(self.hunk)
            for x in xrange(hlen - 1):
                # the hunk starts with the @@ line, so use x+1
                if self.hunk[x + 1][0] == ' ':
                    top += 1
                else:
                    break
            if not toponly:
                for x in xrange(hlen - 1):
                    if self.hunk[hlen - bot - 1][0] == ' ':
                        bot += 1
                    else:
                        break

            # top and bot now count context in the hunk
            # adjust them if either one is short
            context = max(top, bot, 3)
            if bot < context:
                bot = max(0, fuzz - (context - bot))
            else:
                bot = min(fuzz, bot)
            if top < context:
                top = max(0, fuzz - (context - top))
            else:
                top = min(fuzz, top)

            return l[top:len(l)-bot]
        return l

    def old(self, fuzz=0, toponly=False):
        """Old-side lines, optionally with *fuzz* context lines removed."""
        return self.fuzzit(self.a, fuzz, toponly)

    def new(self, fuzz=0, toponly=False):
        """New-side lines, optionally with *fuzz* context lines removed."""
        return self.fuzzit(self.b, fuzz, toponly)
1011
1011
class binhunk(object):
    'A binary patch file. Only understands literals so far.'
    def __init__(self, lr):
        # decoded file content, or None until _read succeeds
        self.text = None
        self.hunk = ['GIT binary patch\n']
        self._read(lr)

    def complete(self):
        """True once the binary payload has been decoded."""
        return self.text is not None

    def new(self):
        """Return the replacement file content as a one-element list."""
        return [self.text]

    def _read(self, lr):
        """Decode a git 'literal N' binary hunk from linereader *lr*.

        Raises PatchError when no literal section is found, or when the
        decompressed payload does not match the declared size.
        """
        line = lr.readline()
        self.hunk.append(line)
        # skip (but record) everything up to the 'literal <size>' header
        while line and not line.startswith('literal '):
            line = lr.readline()
            self.hunk.append(line)
        if not line:
            raise PatchError(_('could not extract binary patch'))
        size = int(line[8:].rstrip())
        dec = []
        line = lr.readline()
        self.hunk.append(line)
        while len(line) > 1:
            # first character encodes the decoded chunk length:
            # 'A'..'Z' -> 1..26, 'a'..'z' -> 27..52
            l = line[0]
            if l <= 'Z' and l >= 'A':
                l = ord(l) - ord('A') + 1
            else:
                l = ord(l) - ord('a') + 27
            dec.append(base85.b85decode(line[1:-1])[:l])
            line = lr.readline()
            self.hunk.append(line)
        text = zlib.decompress(''.join(dec))
        if len(text) != size:
            # bug fix: the original applied '%' to len(text) alone and
            # passed size as a second PatchError argument, raising
            # TypeError instead of this message
            raise PatchError(_('binary patch is %d bytes, not %d') %
                             (len(text), size))
        self.text = text
1051
1051
def parsefilename(str):
    """Extract the file name from a '--- '/'+++' diff header line.

    Trailing junk (timestamps etc.) separated by a tab or, failing
    that, a space is dropped; a tab separator takes precedence.
    """
    # --- filename \t|space stuff
    s = str[4:].rstrip('\r\n')
    for sep in ('\t', ' '):
        cut = s.find(sep)
        if cut >= 0:
            return s[:cut]
    return s
1061
1061
def pathstrip(path, strip):
    """Split *path* after removing *strip* leading directory levels.

    Returns ``(stripped_prefix, remainder)``.  Runs of consecutive
    slashes count as a single separator.  Raises PatchError when the
    path has fewer than *strip* directory components.
    """
    if strip == 0:
        return '', path.rstrip()
    end = len(path)
    pos = 0
    remaining = strip
    while remaining > 0:
        pos = path.find('/', pos)
        if pos == -1:
            raise PatchError(_("unable to strip away %d of %d dirs from %s") %
                             (remaining, strip, path))
        pos += 1
        # consume '//' in the path
        while pos < end - 1 and path[pos] == '/':
            pos += 1
        remaining -= 1
    return path[:pos].lstrip(), path[pos:].rstrip()
1079
1079
def makepatchmeta(backend, afile_orig, bfile_orig, hunk, strip):
    """Build a patchmeta for a plain (non-git) hunk.

    Decides the target file name and whether the hunk creates or
    deletes it, based on the '---'/'+++' names, /dev/null markers and
    which of the stripped paths actually exist in *backend*.
    """
    devnull = "/dev/null"
    nulla = afile_orig == devnull
    nullb = bfile_orig == devnull
    create = nulla and hunk.starta == 0 and hunk.lena == 0
    remove = nullb and hunk.startb == 0 and hunk.lenb == 0
    abase, afile = pathstrip(afile_orig, strip)
    gooda = not nulla and backend.exists(afile)
    bbase, bfile = pathstrip(bfile_orig, strip)
    if afile == bfile:
        goodb = gooda
    else:
        goodb = not nullb and backend.exists(bfile)
    # De Morgan of: not goodb and not gooda and not create
    missing = not (gooda or goodb or create)

    # some diff programs apparently produce patches where the afile is
    # not /dev/null, but afile starts with bfile
    abasedir = afile[:afile.rfind('/') + 1]
    bbasedir = bfile[:bfile.rfind('/') + 1]
    if (missing and abasedir == bbasedir and afile.startswith(bfile)
        and hunk.starta == 0 and hunk.lena == 0):
        create = True
        missing = False

    # If afile is "a/b/foo" and bfile is "a/b/foo.orig" we assume the
    # diff is between a file and its backup. In this case, the original
    # file should be patched (see original mpatch code).
    isbackup = (abase == bbase and bfile.startswith(afile))
    fname = None
    if not missing:
        if gooda and goodb:
            fname = isbackup and afile or bfile
        elif gooda:
            fname = afile

    if not fname:
        if not nullb:
            fname = isbackup and afile or bfile
        elif not nulla:
            fname = afile
        else:
            raise PatchError(_("undefined source and destination files"))

    gp = patchmeta(fname)
    if create:
        gp.op = 'ADD'
    elif remove:
        gp.op = 'DELETE'
    return gp
1128
1128
def scangitpatch(lr, firstline):
    """Pre-scan a git-format patch for copy/rename metadata.

    Git patches may rename 'a' to 'b' and then change 'b', or copy 'a'
    to 'c' and change 'c'.  Such a sequence cannot be applied in stream
    order: the renamed source is gone and the copy source may already
    be modified.  So the whole patch is scanned for copy and rename
    commands first, allowing the copies to be performed ahead of time,
    and the input is rewound afterwards.
    """
    pos = 0
    try:
        pos = lr.fp.tell()
        fp = lr.fp
    except IOError:
        # unseekable input: buffer everything in memory instead
        fp = cStringIO.StringIO(lr.fp.read())
    gitreader = linereader(fp)
    gitreader.push(firstline)
    patches = readgitpatch(gitreader)
    fp.seek(pos)
    return patches
1154
1154
def iterhunks(fp):
    """Read a patch and yield the following events:
    - ("file", afile, bfile, firsthunk): select a new target file.
    - ("hunk", hunk): a new hunk is ready to be applied, follows a
    "file" event.
    - ("git", gitchanges): current diff is in git format, gitchanges
    maps filenames to gitpatch records. Unique event.
    """
    afile = ""
    bfile = ""
    state = None
    hunknum = 0
    emitfile = newfile = False
    gitpatches = None

    # our states
    BFILE = 1
    # context: None = diff flavour unknown yet, True = context diff,
    # False = unified diff
    context = None
    lr = linereader(fp)

    while True:
        x = lr.readline()
        if not x:
            break
        if state == BFILE and (
            (not context and x[0] == '@')
            or (context is not False and x.startswith('***************'))
            or x.startswith('GIT binary patch')):
            gp = None
            if (gitpatches and
                (gitpatches[-1][0] == afile or gitpatches[-1][1] == bfile)):
                # this hunk belongs to the most recently matched git patch
                gp = gitpatches.pop()[2]
            if x.startswith('GIT binary patch'):
                h = binhunk(lr)
            else:
                if context is None and x.startswith('***************'):
                    context = True
                h = hunk(x, hunknum + 1, lr, context)
            hunknum += 1
            if emitfile:
                # first hunk of a new file: announce the file before the hunk
                emitfile = False
                yield 'file', (afile, bfile, h, gp and gp.copy() or None)
            yield 'hunk', h
        elif x.startswith('diff --git'):
            m = gitre.match(x)
            if not m:
                continue
            if gitpatches is None:
                # scan whole input for git metadata
                gitpatches = [('a/' + gp.path, 'b/' + gp.path, gp) for gp
                              in scangitpatch(lr, x)]
                yield 'git', [g[2].copy() for g in gitpatches
                              if g[2].op in ('COPY', 'RENAME')]
                gitpatches.reverse()
            afile = 'a/' + m.group(1)
            bfile = 'b/' + m.group(2)
            # flush intervening git patches that carry no hunks of their
            # own (pure copies/renames/mode changes)
            while afile != gitpatches[-1][0] and bfile != gitpatches[-1][1]:
                gp = gitpatches.pop()[2]
                yield 'file', ('a/' + gp.path, 'b/' + gp.path, None, gp.copy())
            gp = gitpatches[-1][2]
            # copy/rename + modify should modify target, not source
            if gp.op in ('COPY', 'DELETE', 'RENAME', 'ADD') or gp.mode:
                afile = bfile
            newfile = True
        elif x.startswith('---'):
            # check for a unified diff
            l2 = lr.readline()
            if not l2.startswith('+++'):
                lr.push(l2)
                continue
            newfile = True
            context = False
            afile = parsefilename(x)
            bfile = parsefilename(l2)
        elif x.startswith('***'):
            # check for a context diff
            l2 = lr.readline()
            if not l2.startswith('---'):
                lr.push(l2)
                continue
            l3 = lr.readline()
            lr.push(l3)
            if not l3.startswith("***************"):
                lr.push(l2)
                continue
            newfile = True
            context = True
            afile = parsefilename(x)
            bfile = parsefilename(l2)

        if newfile:
            newfile = False
            emitfile = True
            state = BFILE
            hunknum = 0

    # drain remaining git patches that never matched a hunk
    while gitpatches:
        gp = gitpatches.pop()[2]
        yield 'file', ('a/' + gp.path, 'b/' + gp.path, None, gp.copy())
1254
1254
def applydiff(ui, fp, backend, store, strip=1, eolmode='strict'):
    """Reads a patch from fp and tries to apply it.

    Returns 0 for a clean patch, -1 if any rejects were found and 1 if
    there was any fuzz.

    If 'eolmode' is 'strict', the patch content and patched file are
    read in binary mode. Otherwise, line endings are ignored when
    patching then normalized according to 'eolmode'.
    """
    # delegate to _applydiff with the standard patchfile factory
    return _applydiff(ui, fp, patchfile, backend, store,
                      strip=strip, eolmode=eolmode)
1267
1267
1268 def _applydiff(ui, fp, patcher, backend, store, strip=1,
1268 def _applydiff(ui, fp, patcher, backend, store, strip=1,
1269 eolmode='strict'):
1269 eolmode='strict'):
1270
1270
1271 def pstrip(p):
1271 def pstrip(p):
1272 return pathstrip(p, strip - 1)[1]
1272 return pathstrip(p, strip - 1)[1]
1273
1273
1274 rejects = 0
1274 rejects = 0
1275 err = 0
1275 err = 0
1276 current_file = None
1276 current_file = None
1277
1277
1278 for state, values in iterhunks(fp):
1278 for state, values in iterhunks(fp):
1279 if state == 'hunk':
1279 if state == 'hunk':
1280 if not current_file:
1280 if not current_file:
1281 continue
1281 continue
1282 ret = current_file.apply(values)
1282 ret = current_file.apply(values)
1283 if ret > 0:
1283 if ret > 0:
1284 err = 1
1284 err = 1
1285 elif state == 'file':
1285 elif state == 'file':
1286 if current_file:
1286 if current_file:
1287 rejects += current_file.close()
1287 rejects += current_file.close()
1288 current_file = None
1288 current_file = None
1289 afile, bfile, first_hunk, gp = values
1289 afile, bfile, first_hunk, gp = values
1290 if gp:
1290 if gp:
1291 path = pstrip(gp.path)
1291 path = pstrip(gp.path)
1292 gp.path = pstrip(gp.path)
1292 gp.path = pstrip(gp.path)
1293 if gp.oldpath:
1293 if gp.oldpath:
1294 gp.oldpath = pstrip(gp.oldpath)
1294 gp.oldpath = pstrip(gp.oldpath)
1295 else:
1295 else:
1296 gp = makepatchmeta(backend, afile, bfile, first_hunk, strip)
1296 gp = makepatchmeta(backend, afile, bfile, first_hunk, strip)
1297 if gp.op == 'RENAME':
1297 if gp.op == 'RENAME':
1298 backend.unlink(gp.oldpath)
1298 backend.unlink(gp.oldpath)
1299 if not first_hunk:
1299 if not first_hunk:
1300 if gp.op == 'DELETE':
1300 if gp.op == 'DELETE':
1301 backend.unlink(gp.path)
1301 backend.unlink(gp.path)
1302 continue
1302 continue
1303 data, mode = None, None
1303 data, mode = None, None
1304 if gp.op in ('RENAME', 'COPY'):
1304 if gp.op in ('RENAME', 'COPY'):
1305 data, mode = store.getfile(gp.oldpath)[:2]
1305 data, mode = store.getfile(gp.oldpath)[:2]
1306 if gp.mode:
1306 if gp.mode:
1307 mode = gp.mode
1307 mode = gp.mode
1308 if gp.op == 'ADD':
1308 if gp.op == 'ADD':
1309 # Added files without content have no hunk and
1309 # Added files without content have no hunk and
1310 # must be created
1310 # must be created
1311 data = ''
1311 data = ''
1312 if data or mode:
1312 if data or mode:
1313 if (gp.op in ('ADD', 'RENAME', 'COPY')
1313 if (gp.op in ('ADD', 'RENAME', 'COPY')
1314 and backend.exists(gp.path)):
1314 and backend.exists(gp.path)):
1315 raise PatchError(_("cannot create %s: destination "
1315 raise PatchError(_("cannot create %s: destination "
1316 "already exists") % gp.path)
1316 "already exists") % gp.path)
1317 backend.setfile(gp.path, data, mode, gp.oldpath)
1317 backend.setfile(gp.path, data, mode, gp.oldpath)
1318 continue
1318 continue
1319 try:
1319 try:
1320 current_file = patcher(ui, gp, backend, store,
1320 current_file = patcher(ui, gp, backend, store,
1321 eolmode=eolmode)
1321 eolmode=eolmode)
1322 except PatchError, inst:
1322 except PatchError, inst:
1323 ui.warn(str(inst) + '\n')
1323 ui.warn(str(inst) + '\n')
1324 current_file = None
1324 current_file = None
1325 rejects += 1
1325 rejects += 1
1326 continue
1326 continue
1327 elif state == 'git':
1327 elif state == 'git':
1328 for gp in values:
1328 for gp in values:
1329 path = pstrip(gp.oldpath)
1329 path = pstrip(gp.oldpath)
1330 data, mode = backend.getfile(path)
1330 data, mode = backend.getfile(path)
1331 store.setfile(path, data, mode)
1331 store.setfile(path, data, mode)
1332 else:
1332 else:
1333 raise util.Abort(_('unsupported parser state: %s') % state)
1333 raise util.Abort(_('unsupported parser state: %s') % state)
1334
1334
1335 if current_file:
1335 if current_file:
1336 rejects += current_file.close()
1336 rejects += current_file.close()
1337
1337
1338 if rejects:
1338 if rejects:
1339 return -1
1339 return -1
1340 return err
1340 return err
1341
1341
def _externalpatch(ui, repo, patcher, patchname, strip, files,
                   similarity):
    """use <patcher> to apply <patchname> to the working directory.
    returns whether patch was applied with fuzz factor."""

    fuzz = False
    args = []
    cwd = repo.root
    if cwd:
        args.append('-d %s' % util.shellquote(cwd))
    fp = util.popen('%s %s -p%d < %s' % (patcher, ' '.join(args), strip,
                                         util.shellquote(patchname)))
    # fix: initialize before the loop -- 'with fuzz'/'FAILED' output can
    # precede any 'patching file' line, which previously raised
    # UnboundLocalError on pf/printed_file
    pf = None
    printed_file = False
    try:
        for line in fp:
            line = line.rstrip()
            ui.note(line + '\n')
            if line.startswith('patching file '):
                pf = util.parsepatchoutput(line)
                printed_file = False
                files.add(pf)
            elif line.find('with fuzz') >= 0:
                fuzz = True
                if not printed_file:
                    if pf is not None:
                        ui.warn(pf + '\n')
                    printed_file = True
                ui.warn(line + '\n')
            elif line.find('saving rejects to file') >= 0:
                ui.warn(line + '\n')
            elif line.find('FAILED') >= 0:
                if not printed_file:
                    if pf is not None:
                        ui.warn(pf + '\n')
                    printed_file = True
                ui.warn(line + '\n')
    finally:
        if files:
            cfiles = list(files)
            cwd = repo.getcwd()
            if cwd:
                # report paths relative to the user's cwd
                cfiles = [util.pathto(repo.root, cwd, f)
                          for f in cfiles]
            scmutil.addremove(repo, cfiles, similarity=similarity)
    code = fp.close()
    if code:
        raise PatchError(_("patch command failed: %s") %
                         util.explainexit(code)[0])
    return fuzz
1388
1388
def patchbackend(ui, backend, patchobj, strip, files=None, eolmode='strict'):
    """Apply patchobj (a path or an open file object) through backend.

    Returns True if the patch applied with fuzz, False on a clean
    apply; raises PatchError when any hunks were rejected.
    """
    if files is None:
        files = set()
    if eolmode is None:
        eolmode = ui.config('patch', 'eol', 'strict')
    # validate before lowercasing so the error shows the caller's spelling
    if eolmode.lower() not in eolmodes:
        raise util.Abort(_('unsupported line endings type: %s') % eolmode)
    eolmode = eolmode.lower()

    store = filestore()
    try:
        patchfp = open(patchobj, 'rb')
    except TypeError:
        # patchobj is already a file-like object
        patchfp = patchobj
    try:
        result = applydiff(ui, patchfp, backend, store, strip=strip,
                           eolmode=eolmode)
    finally:
        if patchfp != patchobj:
            patchfp.close()
        files.update(backend.close())
        store.close()
    if result < 0:
        raise PatchError(_('patch failed to apply'))
    return result > 0
1414
1414
def internalpatch(ui, repo, patchobj, strip, files=None, eolmode='strict',
                  similarity=0):
    """use builtin patch to apply <patchobj> to the working directory.
    returns whether patch was applied with fuzz factor."""
    return patchbackend(ui, workingbackend(ui, repo, similarity),
                        patchobj, strip, files, eolmode)
1421
1421
def patchrepo(ui, repo, ctx, store, patchobj, strip, files=None,
              eolmode='strict'):
    """Apply patchobj against revision ctx, collecting results in store."""
    return patchbackend(ui, repobackend(ui, repo, ctx, store),
                        patchobj, strip, files, eolmode)
1426
1426
def makememctx(repo, parents, text, user, date, branch, files, store,
               editor=None):
    """Build an in-memory changeset context whose file contents are
    looked up in store; optionally run editor over the message."""
    extra = {}
    if branch:
        extra['branch'] = encoding.fromlocal(branch)

    def getfilectx(repo, memctx, path):
        data, (islink, isexec), copied = store.getfile(path)
        return context.memfilectx(path, data, islink=islink,
                                  isexec=isexec, copied=copied)

    ctx = context.memctx(repo, parents, text, files, getfilectx, user,
                         date, extra)
    if editor:
        ctx._text = editor(repo, ctx, [])
    return ctx
1441
1441
1442 def patch(ui, repo, patchname, strip=1, files=None, eolmode='strict',
1442 def patch(ui, repo, patchname, strip=1, files=None, eolmode='strict',
1443 similarity=0):
1443 similarity=0):
1444 """Apply <patchname> to the working directory.
1444 """Apply <patchname> to the working directory.
1445
1445
1446 'eolmode' specifies how end of lines should be handled. It can be:
1446 'eolmode' specifies how end of lines should be handled. It can be:
1447 - 'strict': inputs are read in binary mode, EOLs are preserved
1447 - 'strict': inputs are read in binary mode, EOLs are preserved
1448 - 'crlf': EOLs are ignored when patching and reset to CRLF
1448 - 'crlf': EOLs are ignored when patching and reset to CRLF
1449 - 'lf': EOLs are ignored when patching and reset to LF
1449 - 'lf': EOLs are ignored when patching and reset to LF
1450 - None: get it from user settings, default to 'strict'
1450 - None: get it from user settings, default to 'strict'
1451 'eolmode' is ignored when using an external patcher program.
1451 'eolmode' is ignored when using an external patcher program.
1452
1452
1453 Returns whether patch was applied with fuzz factor.
1453 Returns whether patch was applied with fuzz factor.
1454 """
1454 """
1455 patcher = ui.config('ui', 'patch')
1455 patcher = ui.config('ui', 'patch')
1456 if files is None:
1456 if files is None:
1457 files = set()
1457 files = set()
1458 try:
1458 try:
1459 if patcher:
1459 if patcher:
1460 return _externalpatch(ui, repo, patcher, patchname, strip,
1460 return _externalpatch(ui, repo, patcher, patchname, strip,
1461 files, similarity)
1461 files, similarity)
1462 return internalpatch(ui, repo, patchname, strip, files, eolmode,
1462 return internalpatch(ui, repo, patchname, strip, files, eolmode,
1463 similarity)
1463 similarity)
1464 except PatchError, err:
1464 except PatchError, err:
1465 raise util.Abort(str(err))
1465 raise util.Abort(str(err))
1466
1466
def changedfiles(ui, repo, patchpath, strip=1):
    """Return the set of file paths touched by the patch at patchpath."""
    backend = fsbackend(ui, repo.root)
    patchfp = open(patchpath, 'rb')
    try:
        touched = set()
        for state, values in iterhunks(patchfp):
            if state == 'file':
                afile, bfile, first_hunk, gp = values
                if gp:
                    # normalize git metadata paths
                    gp.path = pathstrip(gp.path, strip - 1)[1]
                    if gp.oldpath:
                        gp.oldpath = pathstrip(gp.oldpath, strip - 1)[1]
                else:
                    gp = makepatchmeta(backend, afile, bfile, first_hunk,
                                       strip)
                touched.add(gp.path)
                if gp.op == 'RENAME':
                    touched.add(gp.oldpath)
            elif state not in ('hunk', 'git'):
                raise util.Abort(_('unsupported parser state: %s') % state)
        return touched
    finally:
        patchfp.close()
1489
1489
def b85diff(to, tn):
    '''print base85-encoded binary diff'''
    def gitindex(text):
        # git blob id: sha1 over "blob <len>\0" followed by the content
        if not text:
            return hex(nullid)
        s = util.sha1('blob %d\0' % len(text))
        s.update(text)
        return s.hexdigest()

    def fmtline(line):
        # git encodes the raw line length as a letter:
        # 'A'-'Z' for 1-26 bytes, 'a'-'z' for 27-52 bytes
        llen = len(line)
        if llen <= 26:
            lenchar = chr(ord('A') + llen - 1)
        else:
            lenchar = chr(llen - 26 + ord('a') - 1)
        return '%c%s\n' % (lenchar, base85.b85encode(line, True))

    def chunk(text, csize=52):
        for i in xrange(0, len(text), csize):
            yield text[i:i + csize]

    tohash = gitindex(to)
    tnhash = gitindex(tn)
    if tohash == tnhash:
        return ""

    # TODO: deltas
    pieces = ['index %s..%s\nGIT binary patch\nliteral %s\n' %
              (tohash, tnhash, len(tn))]
    for l in chunk(zlib.compress(tn)):
        pieces.append(fmtline(l))
    pieces.append('\n')
    return ''.join(pieces)
1527
1527
class GitDiffRequired(Exception):
    """Raised when a change cannot be represented in plain diff format."""
1530
1530
def diffopts(ui, opts=None, untrusted=False):
    """Build an mdiff.diffopts from command options and [diff] config."""
    def fetch(key, name=None, getter=ui.configbool):
        # a truthy command-line option wins; otherwise consult the
        # [diff] configuration section
        val = opts and opts.get(key)
        if val:
            return val
        return getter('diff', name or key, None, untrusted=untrusted)
    return mdiff.diffopts(
        text=opts and opts.get('text'),
        git=fetch('git'),
        nodates=fetch('nodates'),
        showfunc=fetch('show_function', 'showfunc'),
        ignorews=fetch('ignore_all_space', 'ignorews'),
        ignorewsamount=fetch('ignore_space_change', 'ignorewsamount'),
        ignoreblanklines=fetch('ignore_blank_lines', 'ignoreblanklines'),
        context=fetch('unified', getter=ui.config))
1544
1544
def diff(repo, node1=None, node2=None, match=None, changes=None, opts=None,
         losedatafn=None, prefix=''):
    '''yields diff of changes to files between two nodes, or node and
    working directory.

    if node1 is None, use first dirstate parent instead.
    if node2 is None, compare node1 with working directory.

    losedatafn(**kwarg) is a callable run when opts.upgrade=True and
    every time some change cannot be represented with the current
    patch format. Return False to upgrade to git patch format, True to
    accept the loss or raise an exception to abort the diff. It is
    called with the name of current file being diffed as 'fn'. If set
    to None, patches will always be upgraded to git format when
    necessary.

    prefix is a filename prefix that is prepended to all filenames on
    display (used for subrepos).
    '''

    if opts is None:
        opts = mdiff.defaultopts

    if not node1 and not node2:
        node1 = repo.dirstate.p1()

    def lrugetfilectx():
        # small LRU cache of filelogs keyed by file name (at most ~20
        # entries); 'order' tracks recency for eviction
        cache = {}
        order = []
        def getfilectx(f, ctx):
            fctx = ctx.filectx(f, filelog=cache.get(f))
            if f not in cache:
                if len(cache) > 20:
                    del cache[order.pop(0)]
                cache[f] = fctx.filelog()
            else:
                order.remove(f)
            order.append(f)
            return fctx
        return getfilectx
    getfilectx = lrugetfilectx()

    ctx1 = repo[node1]
    ctx2 = repo[node2]

    if not changes:
        changes = repo.status(ctx1, ctx2, match=match)
    modified, added, removed = changes[:3]

    if not modified and not added and not removed:
        return []

    revs = None
    if not repo.ui.quiet:
        hexfunc = repo.ui.debugflag and hex or short
        revs = [hexfunc(node) for node in [node1, node2] if node]

    # copy/rename detection is only needed for git diffs or when we may
    # upgrade to one
    copy = {}
    if opts.git or opts.upgrade:
        copy = copies.copies(repo, ctx1, ctx2, repo[nullid])[0]

    difffn = lambda opts, losedata: trydiff(repo, revs, ctx1, ctx2,
                modified, added, removed, copy, getfilectx, opts, losedata, prefix)
    if opts.upgrade and not opts.git:
        try:
            def losedata(fn):
                # abort plain-diff generation unless the caller accepts
                # the data loss for file 'fn'
                if not losedatafn or not losedatafn(fn=fn):
                    raise GitDiffRequired()
            # Buffer the whole output until we are sure it can be generated
            return list(difffn(opts.copy(git=False), losedata))
        except GitDiffRequired:
            # restart in git format, which can represent everything
            return difffn(opts.copy(git=True), None)
    else:
        return difffn(opts, None)
1619
1619
def difflabel(func, *args, **kw):
    '''yields 2-tuples of (output, label) based on the output of func()'''
    prefixes = [('diff', 'diff.diffline'),
                ('copy', 'diff.extended'),
                ('rename', 'diff.extended'),
                ('old', 'diff.extended'),
                ('new', 'diff.extended'),
                ('deleted', 'diff.extended'),
                ('---', 'diff.file_a'),
                ('+++', 'diff.file_b'),
                ('@@', 'diff.hunk'),
                ('-', 'diff.deleted'),
                ('+', 'diff.inserted')]

    for chunk in func(*args, **kw):
        lines = chunk.split('\n')
        for i, line in enumerate(lines):
            # re-emit the newlines removed by split()
            if i != 0:
                yield ('\n', '')
            stripline = line
            if line and line[0] in '+-':
                # highlight trailing whitespace, but only in changed lines
                stripline = line.rstrip()
            matched = False
            for prefix, label in prefixes:
                if stripline.startswith(prefix):
                    yield (stripline, label)
                    matched = True
                    break
            if not matched:
                yield (line, '')
            if line != stripline:
                yield (line[len(stripline):], 'diff.trailingwhitespace')
1651
1651
def diffui(*args, **kw):
    '''like diff(), but yields 2-tuples of (output, label) for ui.write()'''
    labeled = difflabel(diff, *args, **kw)
    return labeled
1655
1655
1656
1656
1657 def _addmodehdr(header, omode, nmode):
1657 def _addmodehdr(header, omode, nmode):
1658 if omode != nmode:
1658 if omode != nmode:
1659 header.append('old mode %s\n' % omode)
1659 header.append('old mode %s\n' % omode)
1660 header.append('new mode %s\n' % nmode)
1660 header.append('new mode %s\n' % nmode)
1661
1661
1662 def trydiff(repo, revs, ctx1, ctx2, modified, added, removed,
1662 def trydiff(repo, revs, ctx1, ctx2, modified, added, removed,
1663 copy, getfilectx, opts, losedatafn, prefix):
1663 copy, getfilectx, opts, losedatafn, prefix):
1664
1664
1665 def join(f):
1665 def join(f):
1666 return os.path.join(prefix, f)
1666 return os.path.join(prefix, f)
1667
1667
1668 date1 = util.datestr(ctx1.date())
1668 date1 = util.datestr(ctx1.date())
1669 man1 = ctx1.manifest()
1669 man1 = ctx1.manifest()
1670
1670
1671 gone = set()
1671 gone = set()
1672 gitmode = {'l': '120000', 'x': '100755', '': '100644'}
1672 gitmode = {'l': '120000', 'x': '100755', '': '100644'}
1673
1673
1674 copyto = dict([(v, k) for k, v in copy.items()])
1674 copyto = dict([(v, k) for k, v in copy.items()])
1675
1675
1676 if opts.git:
1676 if opts.git:
1677 revs = None
1677 revs = None
1678
1678
1679 for f in sorted(modified + added + removed):
1679 for f in sorted(modified + added + removed):
1680 to = None
1680 to = None
1681 tn = None
1681 tn = None
1682 dodiff = True
1682 dodiff = True
1683 header = []
1683 header = []
1684 if f in man1:
1684 if f in man1:
1685 to = getfilectx(f, ctx1).data()
1685 to = getfilectx(f, ctx1).data()
1686 if f not in removed:
1686 if f not in removed:
1687 tn = getfilectx(f, ctx2).data()
1687 tn = getfilectx(f, ctx2).data()
1688 a, b = f, f
1688 a, b = f, f
1689 if opts.git or losedatafn:
1689 if opts.git or losedatafn:
1690 if f in added:
1690 if f in added:
1691 mode = gitmode[ctx2.flags(f)]
1691 mode = gitmode[ctx2.flags(f)]
1692 if f in copy or f in copyto:
1692 if f in copy or f in copyto:
1693 if opts.git:
1693 if opts.git:
1694 if f in copy:
1694 if f in copy:
1695 a = copy[f]
1695 a = copy[f]
1696 else:
1696 else:
1697 a = copyto[f]
1697 a = copyto[f]
1698 omode = gitmode[man1.flags(a)]
1698 omode = gitmode[man1.flags(a)]
1699 _addmodehdr(header, omode, mode)
1699 _addmodehdr(header, omode, mode)
1700 if a in removed and a not in gone:
1700 if a in removed and a not in gone:
1701 op = 'rename'
1701 op = 'rename'
1702 gone.add(a)
1702 gone.add(a)
1703 else:
1703 else:
1704 op = 'copy'
1704 op = 'copy'
1705 header.append('%s from %s\n' % (op, join(a)))
1705 header.append('%s from %s\n' % (op, join(a)))
1706 header.append('%s to %s\n' % (op, join(f)))
1706 header.append('%s to %s\n' % (op, join(f)))
1707 to = getfilectx(a, ctx1).data()
1707 to = getfilectx(a, ctx1).data()
1708 else:
1708 else:
1709 losedatafn(f)
1709 losedatafn(f)
1710 else:
1710 else:
1711 if opts.git:
1711 if opts.git:
1712 header.append('new file mode %s\n' % mode)
1712 header.append('new file mode %s\n' % mode)
1713 elif ctx2.flags(f):
1713 elif ctx2.flags(f):
1714 losedatafn(f)
1714 losedatafn(f)
1715 # In theory, if tn was copied or renamed we should check
1715 # In theory, if tn was copied or renamed we should check
1716 # if the source is binary too but the copy record already
1716 # if the source is binary too but the copy record already
1717 # forces git mode.
1717 # forces git mode.
1718 if util.binary(tn):
1718 if util.binary(tn):
1719 if opts.git:
1719 if opts.git:
1720 dodiff = 'binary'
1720 dodiff = 'binary'
1721 else:
1721 else:
1722 losedatafn(f)
1722 losedatafn(f)
1723 if not opts.git and not tn:
1723 if not opts.git and not tn:
1724 # regular diffs cannot represent new empty file
1724 # regular diffs cannot represent new empty file
1725 losedatafn(f)
1725 losedatafn(f)
1726 elif f in removed:
1726 elif f in removed:
1727 if opts.git:
1727 if opts.git:
1728 # have we already reported a copy above?
1728 # have we already reported a copy above?
1729 if ((f in copy and copy[f] in added
1729 if ((f in copy and copy[f] in added
1730 and copyto[copy[f]] == f) or
1730 and copyto[copy[f]] == f) or
1731 (f in copyto and copyto[f] in added
1731 (f in copyto and copyto[f] in added
1732 and copy[copyto[f]] == f)):
1732 and copy[copyto[f]] == f)):
1733 dodiff = False
1733 dodiff = False
1734 else:
1734 else:
1735 header.append('deleted file mode %s\n' %
1735 header.append('deleted file mode %s\n' %
1736 gitmode[man1.flags(f)])
1736 gitmode[man1.flags(f)])
1737 elif not to or util.binary(to):
1737 elif not to or util.binary(to):
1738 # regular diffs cannot represent empty file deletion
1738 # regular diffs cannot represent empty file deletion
1739 losedatafn(f)
1739 losedatafn(f)
1740 else:
1740 else:
1741 oflag = man1.flags(f)
1741 oflag = man1.flags(f)
1742 nflag = ctx2.flags(f)
1742 nflag = ctx2.flags(f)
1743 binary = util.binary(to) or util.binary(tn)
1743 binary = util.binary(to) or util.binary(tn)
1744 if opts.git:
1744 if opts.git:
1745 _addmodehdr(header, gitmode[oflag], gitmode[nflag])
1745 _addmodehdr(header, gitmode[oflag], gitmode[nflag])
1746 if binary:
1746 if binary:
1747 dodiff = 'binary'
1747 dodiff = 'binary'
1748 elif binary or nflag != oflag:
1748 elif binary or nflag != oflag:
1749 losedatafn(f)
1749 losedatafn(f)
1750 if opts.git:
1750 if opts.git:
1751 header.insert(0, mdiff.diffline(revs, join(a), join(b), opts))
1751 header.insert(0, mdiff.diffline(revs, join(a), join(b), opts))
1752
1752
1753 if dodiff:
1753 if dodiff:
1754 if dodiff == 'binary':
1754 if dodiff == 'binary':
1755 text = b85diff(to, tn)
1755 text = b85diff(to, tn)
1756 else:
1756 else:
1757 text = mdiff.unidiff(to, date1,
1757 text = mdiff.unidiff(to, date1,
1758 # ctx2 date may be dynamic
1758 # ctx2 date may be dynamic
1759 tn, util.datestr(ctx2.date()),
1759 tn, util.datestr(ctx2.date()),
1760 join(a), join(b), revs, opts=opts)
1760 join(a), join(b), revs, opts=opts)
1761 if header and (text or len(header) > 1):
1761 if header and (text or len(header) > 1):
1762 yield ''.join(header)
1762 yield ''.join(header)
1763 if text:
1763 if text:
1764 yield text
1764 yield text
1765
1765
1766 def diffstatsum(stats):
1766 def diffstatsum(stats):
1767 maxfile, maxtotal, addtotal, removetotal, binary = 0, 0, 0, 0, False
1767 maxfile, maxtotal, addtotal, removetotal, binary = 0, 0, 0, 0, False
1768 for f, a, r, b in stats:
1768 for f, a, r, b in stats:
1769 maxfile = max(maxfile, encoding.colwidth(f))
1769 maxfile = max(maxfile, encoding.colwidth(f))
1770 maxtotal = max(maxtotal, a + r)
1770 maxtotal = max(maxtotal, a + r)
1771 addtotal += a
1771 addtotal += a
1772 removetotal += r
1772 removetotal += r
1773 binary = binary or b
1773 binary = binary or b
1774
1774
1775 return maxfile, maxtotal, addtotal, removetotal, binary
1775 return maxfile, maxtotal, addtotal, removetotal, binary
1776
1776
1777 def diffstatdata(lines):
1777 def diffstatdata(lines):
1778 diffre = re.compile('^diff .*-r [a-z0-9]+\s(.*)$')
1778 diffre = re.compile('^diff .*-r [a-z0-9]+\s(.*)$')
1779
1779
1780 results = []
1780 results = []
1781 filename, adds, removes = None, 0, 0
1781 filename, adds, removes = None, 0, 0
1782
1782
1783 def addresult():
1783 def addresult():
1784 if filename:
1784 if filename:
1785 isbinary = adds == 0 and removes == 0
1785 isbinary = adds == 0 and removes == 0
1786 results.append((filename, adds, removes, isbinary))
1786 results.append((filename, adds, removes, isbinary))
1787
1787
1788 for line in lines:
1788 for line in lines:
1789 if line.startswith('diff'):
1789 if line.startswith('diff'):
1790 addresult()
1790 addresult()
1791 # set numbers to 0 anyway when starting new file
1791 # set numbers to 0 anyway when starting new file
1792 adds, removes = 0, 0
1792 adds, removes = 0, 0
1793 if line.startswith('diff --git'):
1793 if line.startswith('diff --git'):
1794 filename = gitre.search(line).group(1)
1794 filename = gitre.search(line).group(1)
1795 elif line.startswith('diff -r'):
1795 elif line.startswith('diff -r'):
1796 # format: "diff -r ... -r ... filename"
1796 # format: "diff -r ... -r ... filename"
1797 filename = diffre.search(line).group(1)
1797 filename = diffre.search(line).group(1)
1798 elif line.startswith('+') and not line.startswith('+++'):
1798 elif line.startswith('+') and not line.startswith('+++'):
1799 adds += 1
1799 adds += 1
1800 elif line.startswith('-') and not line.startswith('---'):
1800 elif line.startswith('-') and not line.startswith('---'):
1801 removes += 1
1801 removes += 1
1802 addresult()
1802 addresult()
1803 return results
1803 return results
1804
1804
1805 def diffstat(lines, width=80, git=False):
1805 def diffstat(lines, width=80, git=False):
1806 output = []
1806 output = []
1807 stats = diffstatdata(lines)
1807 stats = diffstatdata(lines)
1808 maxname, maxtotal, totaladds, totalremoves, hasbinary = diffstatsum(stats)
1808 maxname, maxtotal, totaladds, totalremoves, hasbinary = diffstatsum(stats)
1809
1809
1810 countwidth = len(str(maxtotal))
1810 countwidth = len(str(maxtotal))
1811 if hasbinary and countwidth < 3:
1811 if hasbinary and countwidth < 3:
1812 countwidth = 3
1812 countwidth = 3
1813 graphwidth = width - countwidth - maxname - 6
1813 graphwidth = width - countwidth - maxname - 6
1814 if graphwidth < 10:
1814 if graphwidth < 10:
1815 graphwidth = 10
1815 graphwidth = 10
1816
1816
1817 def scale(i):
1817 def scale(i):
1818 if maxtotal <= graphwidth:
1818 if maxtotal <= graphwidth:
1819 return i
1819 return i
1820 # If diffstat runs out of room it doesn't print anything,
1820 # If diffstat runs out of room it doesn't print anything,
1821 # which isn't very useful, so always print at least one + or -
1821 # which isn't very useful, so always print at least one + or -
1822 # if there were at least some changes.
1822 # if there were at least some changes.
1823 return max(i * graphwidth // maxtotal, int(bool(i)))
1823 return max(i * graphwidth // maxtotal, int(bool(i)))
1824
1824
1825 for filename, adds, removes, isbinary in stats:
1825 for filename, adds, removes, isbinary in stats:
1826 if git and isbinary:
1826 if git and isbinary:
1827 count = 'Bin'
1827 count = 'Bin'
1828 else:
1828 else:
1829 count = adds + removes
1829 count = adds + removes
1830 pluses = '+' * scale(adds)
1830 pluses = '+' * scale(adds)
1831 minuses = '-' * scale(removes)
1831 minuses = '-' * scale(removes)
1832 output.append(' %s%s | %*s %s%s\n' %
1832 output.append(' %s%s | %*s %s%s\n' %
1833 (filename, ' ' * (maxname - encoding.colwidth(filename)),
1833 (filename, ' ' * (maxname - encoding.colwidth(filename)),
1834 countwidth, count, pluses, minuses))
1834 countwidth, count, pluses, minuses))
1835
1835
1836 if stats:
1836 if stats:
1837 output.append(_(' %d files changed, %d insertions(+), %d deletions(-)\n')
1837 output.append(_(' %d files changed, %d insertions(+), %d deletions(-)\n')
1838 % (len(stats), totaladds, totalremoves))
1838 % (len(stats), totaladds, totalremoves))
1839
1839
1840 return ''.join(output)
1840 return ''.join(output)
1841
1841
1842 def diffstatui(*args, **kw):
1842 def diffstatui(*args, **kw):
1843 '''like diffstat(), but yields 2-tuples of (output, label) for
1843 '''like diffstat(), but yields 2-tuples of (output, label) for
1844 ui.write()
1844 ui.write()
1845 '''
1845 '''
1846
1846
1847 for line in diffstat(*args, **kw).splitlines():
1847 for line in diffstat(*args, **kw).splitlines():
1848 if line and line[-1] in '+-':
1848 if line and line[-1] in '+-':
1849 name, graph = line.rsplit(' ', 1)
1849 name, graph = line.rsplit(' ', 1)
1850 yield (name + ' ', '')
1850 yield (name + ' ', '')
1851 m = re.search(r'\++', graph)
1851 m = re.search(r'\++', graph)
1852 if m:
1852 if m:
1853 yield (m.group(0), 'diffstat.inserted')
1853 yield (m.group(0), 'diffstat.inserted')
1854 m = re.search(r'-+', graph)
1854 m = re.search(r'-+', graph)
1855 if m:
1855 if m:
1856 yield (m.group(0), 'diffstat.deleted')
1856 yield (m.group(0), 'diffstat.deleted')
1857 else:
1857 else:
1858 yield (line, '')
1858 yield (line, '')
1859 yield ('\n', '')
1859 yield ('\n', '')
@@ -1,175 +1,175 b''
1 # test-batching.py - tests for transparent command batching
1 # test-batching.py - tests for transparent command batching
2 #
2 #
3 # Copyright 2011 Peter Arrenbrecht <peter@arrenbrecht.ch>
3 # Copyright 2011 Peter Arrenbrecht <peter@arrenbrecht.ch>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from mercurial.wireproto import localbatch, remotebatch, batchable, future
8 from mercurial.wireproto import localbatch, remotebatch, batchable, future
9
9
10 # equivalent of repo.repository
10 # equivalent of repo.repository
11 class thing(object):
11 class thing(object):
12 def hello(self):
12 def hello(self):
13 return "Ready."
13 return "Ready."
14
14
15 # equivalent of localrepo.localrepository
15 # equivalent of localrepo.localrepository
16 class localthing(thing):
16 class localthing(thing):
17 def foo(self, one, two=None):
17 def foo(self, one, two=None):
18 if one:
18 if one:
19 return "%s and %s" % (one, two,)
19 return "%s and %s" % (one, two,)
20 return "Nope"
20 return "Nope"
21 def bar(self, b, a):
21 def bar(self, b, a):
22 return "%s und %s" % (b, a,)
22 return "%s und %s" % (b, a,)
23 def greet(self, name=None):
23 def greet(self, name=None):
24 return "Hello, %s" % name
24 return "Hello, %s" % name
25 def batch(self):
25 def batch(self):
26 '''Support for local batching.'''
26 '''Support for local batching.'''
27 return localbatch(self)
27 return localbatch(self)
28
28
29 # usage of "thing" interface
29 # usage of "thing" interface
30 def use(it):
30 def use(it):
31
31
32 # Direct call to base method shared between client and server.
32 # Direct call to base method shared between client and server.
33 print it.hello()
33 print it.hello()
34
34
35 # Direct calls to proxied methods. They cause individual roundtrips.
35 # Direct calls to proxied methods. They cause individual roundtrips.
36 print it.foo("Un", two="Deux")
36 print it.foo("Un", two="Deux")
37 print it.bar("Eins", "Zwei")
37 print it.bar("Eins", "Zwei")
38
38
39 # Batched call to a couple of (possibly proxied) methods.
39 # Batched call to a couple of (possibly proxied) methods.
40 batch = it.batch()
40 batch = it.batch()
41 # The calls return futures to eventually hold results.
41 # The calls return futures to eventually hold results.
42 foo = batch.foo(one="One", two="Two")
42 foo = batch.foo(one="One", two="Two")
43 foo2 = batch.foo(None)
43 foo2 = batch.foo(None)
44 bar = batch.bar("Eins", "Zwei")
44 bar = batch.bar("Eins", "Zwei")
45 # We can call non-batchable proxy methods, but the break the current batch
45 # We can call non-batchable proxy methods, but the break the current batch
46 # request and cause additional roundtrips.
46 # request and cause additional roundtrips.
47 greet = batch.greet(name="John Smith")
47 greet = batch.greet(name="John Smith")
48 # We can also add local methods into the mix, but they break the batch too.
48 # We can also add local methods into the mix, but they break the batch too.
49 hello = batch.hello()
49 hello = batch.hello()
50 bar2 = batch.bar(b="Uno", a="Due")
50 bar2 = batch.bar(b="Uno", a="Due")
51 # Only now are all the calls executed in sequence, with as few roundtrips
51 # Only now are all the calls executed in sequence, with as few roundtrips
52 # as possible.
52 # as possible.
53 batch.submit()
53 batch.submit()
54 # After the call to submit, the futures actually contain values.
54 # After the call to submit, the futures actually contain values.
55 print foo.value
55 print foo.value
56 print foo2.value
56 print foo2.value
57 print bar.value
57 print bar.value
58 print greet.value
58 print greet.value
59 print hello.value
59 print hello.value
60 print bar2.value
60 print bar2.value
61
61
62 # local usage
62 # local usage
63 mylocal = localthing()
63 mylocal = localthing()
64 print
64 print
65 print "== Local"
65 print "== Local"
66 use(mylocal)
66 use(mylocal)
67
67
68 # demo remoting; mimicks what wireproto and HTTP/SSH do
68 # demo remoting; mimicks what wireproto and HTTP/SSH do
69
69
70 # shared
70 # shared
71
71
72 def escapearg(plain):
72 def escapearg(plain):
73 return (plain
73 return (plain
74 .replace(':', '::')
74 .replace(':', '::')
75 .replace(',', ':,')
75 .replace(',', ':,')
76 .replace(';', ':;')
76 .replace(';', ':;')
77 .replace('=', ':='))
77 .replace('=', ':='))
78 def unescapearg(escaped):
78 def unescapearg(escaped):
79 return (escaped
79 return (escaped
80 .replace(':=', '=')
80 .replace(':=', '=')
81 .replace(':;', ';')
81 .replace(':;', ';')
82 .replace(':,', ',')
82 .replace(':,', ',')
83 .replace('::', ':'))
83 .replace('::', ':'))
84
84
85 # server side
85 # server side
86
86
87 # equivalent of wireproto's global functions
87 # equivalent of wireproto's global functions
88 class server:
88 class server(object):
89 def __init__(self, local):
89 def __init__(self, local):
90 self.local = local
90 self.local = local
91 def _call(self, name, args):
91 def _call(self, name, args):
92 args = dict(arg.split('=', 1) for arg in args)
92 args = dict(arg.split('=', 1) for arg in args)
93 return getattr(self, name)(**args)
93 return getattr(self, name)(**args)
94 def perform(self, req):
94 def perform(self, req):
95 print "REQ:", req
95 print "REQ:", req
96 name, args = req.split('?', 1)
96 name, args = req.split('?', 1)
97 args = args.split('&')
97 args = args.split('&')
98 vals = dict(arg.split('=', 1) for arg in args)
98 vals = dict(arg.split('=', 1) for arg in args)
99 res = getattr(self, name)(**vals)
99 res = getattr(self, name)(**vals)
100 print " ->", res
100 print " ->", res
101 return res
101 return res
102 def batch(self, cmds):
102 def batch(self, cmds):
103 res = []
103 res = []
104 for pair in cmds.split(';'):
104 for pair in cmds.split(';'):
105 name, args = pair.split(':', 1)
105 name, args = pair.split(':', 1)
106 vals = {}
106 vals = {}
107 for a in args.split(','):
107 for a in args.split(','):
108 if a:
108 if a:
109 n, v = a.split('=')
109 n, v = a.split('=')
110 vals[n] = unescapearg(v)
110 vals[n] = unescapearg(v)
111 res.append(escapearg(getattr(self, name)(**vals)))
111 res.append(escapearg(getattr(self, name)(**vals)))
112 return ';'.join(res)
112 return ';'.join(res)
113 def foo(self, one, two):
113 def foo(self, one, two):
114 return mangle(self.local.foo(unmangle(one), unmangle(two)))
114 return mangle(self.local.foo(unmangle(one), unmangle(two)))
115 def bar(self, b, a):
115 def bar(self, b, a):
116 return mangle(self.local.bar(unmangle(b), unmangle(a)))
116 return mangle(self.local.bar(unmangle(b), unmangle(a)))
117 def greet(self, name):
117 def greet(self, name):
118 return mangle(self.local.greet(unmangle(name)))
118 return mangle(self.local.greet(unmangle(name)))
119 myserver = server(mylocal)
119 myserver = server(mylocal)
120
120
121 # local side
121 # local side
122
122
123 # equivalent of wireproto.encode/decodelist, that is, type-specific marshalling
123 # equivalent of wireproto.encode/decodelist, that is, type-specific marshalling
124 # here we just transform the strings a bit to check we're properly en-/decoding
124 # here we just transform the strings a bit to check we're properly en-/decoding
125 def mangle(s):
125 def mangle(s):
126 return ''.join(chr(ord(c) + 1) for c in s)
126 return ''.join(chr(ord(c) + 1) for c in s)
127 def unmangle(s):
127 def unmangle(s):
128 return ''.join(chr(ord(c) - 1) for c in s)
128 return ''.join(chr(ord(c) - 1) for c in s)
129
129
130 # equivalent of wireproto.wirerepository and something like http's wire format
130 # equivalent of wireproto.wirerepository and something like http's wire format
131 class remotething(thing):
131 class remotething(thing):
132 def __init__(self, server):
132 def __init__(self, server):
133 self.server = server
133 self.server = server
134 def _submitone(self, name, args):
134 def _submitone(self, name, args):
135 req = name + '?' + '&'.join(['%s=%s' % (n, v) for n, v in args])
135 req = name + '?' + '&'.join(['%s=%s' % (n, v) for n, v in args])
136 return self.server.perform(req)
136 return self.server.perform(req)
137 def _submitbatch(self, cmds):
137 def _submitbatch(self, cmds):
138 req = []
138 req = []
139 for name, args in cmds:
139 for name, args in cmds:
140 args = ','.join(n + '=' + escapearg(v) for n, v in args)
140 args = ','.join(n + '=' + escapearg(v) for n, v in args)
141 req.append(name + ':' + args)
141 req.append(name + ':' + args)
142 req = ';'.join(req)
142 req = ';'.join(req)
143 res = self._submitone('batch', [('cmds', req,)])
143 res = self._submitone('batch', [('cmds', req,)])
144 return res.split(';')
144 return res.split(';')
145
145
146 def batch(self):
146 def batch(self):
147 return remotebatch(self)
147 return remotebatch(self)
148
148
149 @batchable
149 @batchable
150 def foo(self, one, two=None):
150 def foo(self, one, two=None):
151 if not one:
151 if not one:
152 yield "Nope", None
152 yield "Nope", None
153 encargs = [('one', mangle(one),), ('two', mangle(two),)]
153 encargs = [('one', mangle(one),), ('two', mangle(two),)]
154 encresref = future()
154 encresref = future()
155 yield encargs, encresref
155 yield encargs, encresref
156 yield unmangle(encresref.value)
156 yield unmangle(encresref.value)
157
157
158 @batchable
158 @batchable
159 def bar(self, b, a):
159 def bar(self, b, a):
160 encresref = future()
160 encresref = future()
161 yield [('b', mangle(b),), ('a', mangle(a),)], encresref
161 yield [('b', mangle(b),), ('a', mangle(a),)], encresref
162 yield unmangle(encresref.value)
162 yield unmangle(encresref.value)
163
163
164 # greet is coded directly. It therefore does not support batching. If it
164 # greet is coded directly. It therefore does not support batching. If it
165 # does appear in a batch, the batch is split around greet, and the call to
165 # does appear in a batch, the batch is split around greet, and the call to
166 # greet is done in its own roundtrip.
166 # greet is done in its own roundtrip.
167 def greet(self, name=None):
167 def greet(self, name=None):
168 return unmangle(self._submitone('greet', [('name', mangle(name),)]))
168 return unmangle(self._submitone('greet', [('name', mangle(name),)]))
169
169
170 # demo remote usage
170 # demo remote usage
171
171
172 myproxy = remotething(myserver)
172 myproxy = remotething(myserver)
173 print
173 print
174 print "== Remote"
174 print "== Remote"
175 use(myproxy)
175 use(myproxy)
@@ -1,45 +1,45 b''
1 from mercurial import wireproto
1 from mercurial import wireproto
2
2
3 class proto():
3 class proto(object):
4 def __init__(self, args):
4 def __init__(self, args):
5 self.args = args
5 self.args = args
6 def getargs(self, spec):
6 def getargs(self, spec):
7 args = self.args
7 args = self.args
8 args.setdefault('*', {})
8 args.setdefault('*', {})
9 names = spec.split()
9 names = spec.split()
10 return [args[n] for n in names]
10 return [args[n] for n in names]
11
11
12 class clientrepo(wireproto.wirerepository):
12 class clientrepo(wireproto.wirerepository):
13 def __init__(self, serverrepo):
13 def __init__(self, serverrepo):
14 self.serverrepo = serverrepo
14 self.serverrepo = serverrepo
15 def _call(self, cmd, **args):
15 def _call(self, cmd, **args):
16 return wireproto.dispatch(self.serverrepo, proto(args), cmd)
16 return wireproto.dispatch(self.serverrepo, proto(args), cmd)
17
17
18 @wireproto.batchable
18 @wireproto.batchable
19 def greet(self, name):
19 def greet(self, name):
20 f = wireproto.future()
20 f = wireproto.future()
21 yield wireproto.todict(name=mangle(name)), f
21 yield wireproto.todict(name=mangle(name)), f
22 yield unmangle(f.value)
22 yield unmangle(f.value)
23
23
24 class serverrepo():
24 class serverrepo(object):
25 def greet(self, name):
25 def greet(self, name):
26 return "Hello, " + name
26 return "Hello, " + name
27
27
28 def mangle(s):
28 def mangle(s):
29 return ''.join(chr(ord(c) + 1) for c in s)
29 return ''.join(chr(ord(c) + 1) for c in s)
30 def unmangle(s):
30 def unmangle(s):
31 return ''.join(chr(ord(c) - 1) for c in s)
31 return ''.join(chr(ord(c) - 1) for c in s)
32
32
33 def greet(repo, proto, name):
33 def greet(repo, proto, name):
34 return mangle(repo.greet(unmangle(name)))
34 return mangle(repo.greet(unmangle(name)))
35
35
36 wireproto.commands['greet'] = (greet, 'name',)
36 wireproto.commands['greet'] = (greet, 'name',)
37
37
38 srv = serverrepo()
38 srv = serverrepo()
39 clt = clientrepo(srv)
39 clt = clientrepo(srv)
40
40
41 print clt.greet("Foobar")
41 print clt.greet("Foobar")
42 b = clt.batch()
42 b = clt.batch()
43 fs = [b.greet(s) for s in ["Fo, =;o", "Bar"]]
43 fs = [b.greet(s) for s in ["Fo, =;o", "Bar"]]
44 b.submit()
44 b.submit()
45 print [f.value for f in fs]
45 print [f.value for f in fs]
General Comments 0
You need to be logged in to leave comments. Login now