##// END OF EJS Templates
delete some dead comments and docstrings
Mads Kiilerich -
r17426:9724f8f8 default
parent child Browse files
Show More
@@ -1,360 +1,359
1 # monotone.py - monotone support for the convert extension
1 # monotone.py - monotone support for the convert extension
2 #
2 #
3 # Copyright 2008, 2009 Mikkel Fahnoe Jorgensen <mikkel@dvide.com> and
3 # Copyright 2008, 2009 Mikkel Fahnoe Jorgensen <mikkel@dvide.com> and
4 # others
4 # others
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 import os, re
9 import os, re
10 from mercurial import util
10 from mercurial import util
11 from common import NoRepo, commit, converter_source, checktool
11 from common import NoRepo, commit, converter_source, checktool
12 from common import commandline
12 from common import commandline
13 from mercurial.i18n import _
13 from mercurial.i18n import _
14
14
15 class monotone_source(converter_source, commandline):
15 class monotone_source(converter_source, commandline):
16 def __init__(self, ui, path=None, rev=None):
16 def __init__(self, ui, path=None, rev=None):
17 converter_source.__init__(self, ui, path, rev)
17 converter_source.__init__(self, ui, path, rev)
18 commandline.__init__(self, ui, 'mtn')
18 commandline.__init__(self, ui, 'mtn')
19
19
20 self.ui = ui
20 self.ui = ui
21 self.path = path
21 self.path = path
22 self.automatestdio = False
22 self.automatestdio = False
23 self.rev = rev
23 self.rev = rev
24
24
25 norepo = NoRepo(_("%s does not look like a monotone repository")
25 norepo = NoRepo(_("%s does not look like a monotone repository")
26 % path)
26 % path)
27 if not os.path.exists(os.path.join(path, '_MTN')):
27 if not os.path.exists(os.path.join(path, '_MTN')):
28 # Could be a monotone repository (SQLite db file)
28 # Could be a monotone repository (SQLite db file)
29 try:
29 try:
30 f = file(path, 'rb')
30 f = file(path, 'rb')
31 header = f.read(16)
31 header = f.read(16)
32 f.close()
32 f.close()
33 except IOError:
33 except IOError:
34 header = ''
34 header = ''
35 if header != 'SQLite format 3\x00':
35 if header != 'SQLite format 3\x00':
36 raise norepo
36 raise norepo
37
37
38 # regular expressions for parsing monotone output
38 # regular expressions for parsing monotone output
39 space = r'\s*'
39 space = r'\s*'
40 name = r'\s+"((?:\\"|[^"])*)"\s*'
40 name = r'\s+"((?:\\"|[^"])*)"\s*'
41 value = name
41 value = name
42 revision = r'\s+\[(\w+)\]\s*'
42 revision = r'\s+\[(\w+)\]\s*'
43 lines = r'(?:.|\n)+'
43 lines = r'(?:.|\n)+'
44
44
45 self.dir_re = re.compile(space + "dir" + name)
45 self.dir_re = re.compile(space + "dir" + name)
46 self.file_re = re.compile(space + "file" + name +
46 self.file_re = re.compile(space + "file" + name +
47 "content" + revision)
47 "content" + revision)
48 self.add_file_re = re.compile(space + "add_file" + name +
48 self.add_file_re = re.compile(space + "add_file" + name +
49 "content" + revision)
49 "content" + revision)
50 self.patch_re = re.compile(space + "patch" + name +
50 self.patch_re = re.compile(space + "patch" + name +
51 "from" + revision + "to" + revision)
51 "from" + revision + "to" + revision)
52 self.rename_re = re.compile(space + "rename" + name + "to" + name)
52 self.rename_re = re.compile(space + "rename" + name + "to" + name)
53 self.delete_re = re.compile(space + "delete" + name)
53 self.delete_re = re.compile(space + "delete" + name)
54 self.tag_re = re.compile(space + "tag" + name + "revision" +
54 self.tag_re = re.compile(space + "tag" + name + "revision" +
55 revision)
55 revision)
56 self.cert_re = re.compile(lines + space + "name" + name +
56 self.cert_re = re.compile(lines + space + "name" + name +
57 "value" + value)
57 "value" + value)
58
58
59 attr = space + "file" + lines + space + "attr" + space
59 attr = space + "file" + lines + space + "attr" + space
60 self.attr_execute_re = re.compile(attr + '"mtn:execute"' +
60 self.attr_execute_re = re.compile(attr + '"mtn:execute"' +
61 space + '"true"')
61 space + '"true"')
62
62
63 # cached data
63 # cached data
64 self.manifest_rev = None
64 self.manifest_rev = None
65 self.manifest = None
65 self.manifest = None
66 self.files = None
66 self.files = None
67 self.dirs = None
67 self.dirs = None
68
68
69 checktool('mtn', abort=False)
69 checktool('mtn', abort=False)
70
70
71 def mtnrun(self, *args, **kwargs):
71 def mtnrun(self, *args, **kwargs):
72 if self.automatestdio:
72 if self.automatestdio:
73 return self.mtnrunstdio(*args, **kwargs)
73 return self.mtnrunstdio(*args, **kwargs)
74 else:
74 else:
75 return self.mtnrunsingle(*args, **kwargs)
75 return self.mtnrunsingle(*args, **kwargs)
76
76
77 def mtnrunsingle(self, *args, **kwargs):
77 def mtnrunsingle(self, *args, **kwargs):
78 kwargs['d'] = self.path
78 kwargs['d'] = self.path
79 return self.run0('automate', *args, **kwargs)
79 return self.run0('automate', *args, **kwargs)
80
80
81 def mtnrunstdio(self, *args, **kwargs):
81 def mtnrunstdio(self, *args, **kwargs):
82 # Prepare the command in automate stdio format
82 # Prepare the command in automate stdio format
83 command = []
83 command = []
84 for k, v in kwargs.iteritems():
84 for k, v in kwargs.iteritems():
85 command.append("%s:%s" % (len(k), k))
85 command.append("%s:%s" % (len(k), k))
86 if v:
86 if v:
87 command.append("%s:%s" % (len(v), v))
87 command.append("%s:%s" % (len(v), v))
88 if command:
88 if command:
89 command.insert(0, 'o')
89 command.insert(0, 'o')
90 command.append('e')
90 command.append('e')
91
91
92 command.append('l')
92 command.append('l')
93 for arg in args:
93 for arg in args:
94 command += "%s:%s" % (len(arg), arg)
94 command += "%s:%s" % (len(arg), arg)
95 command.append('e')
95 command.append('e')
96 command = ''.join(command)
96 command = ''.join(command)
97
97
98 self.ui.debug("mtn: sending '%s'\n" % command)
98 self.ui.debug("mtn: sending '%s'\n" % command)
99 self.mtnwritefp.write(command)
99 self.mtnwritefp.write(command)
100 self.mtnwritefp.flush()
100 self.mtnwritefp.flush()
101
101
102 return self.mtnstdioreadcommandoutput(command)
102 return self.mtnstdioreadcommandoutput(command)
103
103
104 def mtnstdioreadpacket(self):
104 def mtnstdioreadpacket(self):
105 read = None
105 read = None
106 commandnbr = ''
106 commandnbr = ''
107 while read != ':':
107 while read != ':':
108 read = self.mtnreadfp.read(1)
108 read = self.mtnreadfp.read(1)
109 if not read:
109 if not read:
110 raise util.Abort(_('bad mtn packet - no end of commandnbr'))
110 raise util.Abort(_('bad mtn packet - no end of commandnbr'))
111 commandnbr += read
111 commandnbr += read
112 commandnbr = commandnbr[:-1]
112 commandnbr = commandnbr[:-1]
113
113
114 stream = self.mtnreadfp.read(1)
114 stream = self.mtnreadfp.read(1)
115 if stream not in 'mewptl':
115 if stream not in 'mewptl':
116 raise util.Abort(_('bad mtn packet - bad stream type %s') % stream)
116 raise util.Abort(_('bad mtn packet - bad stream type %s') % stream)
117
117
118 read = self.mtnreadfp.read(1)
118 read = self.mtnreadfp.read(1)
119 if read != ':':
119 if read != ':':
120 raise util.Abort(_('bad mtn packet - no divider before size'))
120 raise util.Abort(_('bad mtn packet - no divider before size'))
121
121
122 read = None
122 read = None
123 lengthstr = ''
123 lengthstr = ''
124 while read != ':':
124 while read != ':':
125 read = self.mtnreadfp.read(1)
125 read = self.mtnreadfp.read(1)
126 if not read:
126 if not read:
127 raise util.Abort(_('bad mtn packet - no end of packet size'))
127 raise util.Abort(_('bad mtn packet - no end of packet size'))
128 lengthstr += read
128 lengthstr += read
129 try:
129 try:
130 length = long(lengthstr[:-1])
130 length = long(lengthstr[:-1])
131 except TypeError:
131 except TypeError:
132 raise util.Abort(_('bad mtn packet - bad packet size %s')
132 raise util.Abort(_('bad mtn packet - bad packet size %s')
133 % lengthstr)
133 % lengthstr)
134
134
135 read = self.mtnreadfp.read(length)
135 read = self.mtnreadfp.read(length)
136 if len(read) != length:
136 if len(read) != length:
137 raise util.Abort(_("bad mtn packet - unable to read full packet "
137 raise util.Abort(_("bad mtn packet - unable to read full packet "
138 "read %s of %s") % (len(read), length))
138 "read %s of %s") % (len(read), length))
139
139
140 return (commandnbr, stream, length, read)
140 return (commandnbr, stream, length, read)
141
141
142 def mtnstdioreadcommandoutput(self, command):
142 def mtnstdioreadcommandoutput(self, command):
143 retval = []
143 retval = []
144 while True:
144 while True:
145 commandnbr, stream, length, output = self.mtnstdioreadpacket()
145 commandnbr, stream, length, output = self.mtnstdioreadpacket()
146 self.ui.debug('mtn: read packet %s:%s:%s\n' %
146 self.ui.debug('mtn: read packet %s:%s:%s\n' %
147 (commandnbr, stream, length))
147 (commandnbr, stream, length))
148
148
149 if stream == 'l':
149 if stream == 'l':
150 # End of command
150 # End of command
151 if output != '0':
151 if output != '0':
152 raise util.Abort(_("mtn command '%s' returned %s") %
152 raise util.Abort(_("mtn command '%s' returned %s") %
153 (command, output))
153 (command, output))
154 break
154 break
155 elif stream in 'ew':
155 elif stream in 'ew':
156 # Error, warning output
156 # Error, warning output
157 self.ui.warn(_('%s error:\n') % self.command)
157 self.ui.warn(_('%s error:\n') % self.command)
158 self.ui.warn(output)
158 self.ui.warn(output)
159 elif stream == 'p':
159 elif stream == 'p':
160 # Progress messages
160 # Progress messages
161 self.ui.debug('mtn: ' + output)
161 self.ui.debug('mtn: ' + output)
162 elif stream == 'm':
162 elif stream == 'm':
163 # Main stream - command output
163 # Main stream - command output
164 retval.append(output)
164 retval.append(output)
165
165
166 return ''.join(retval)
166 return ''.join(retval)
167
167
168 def mtnloadmanifest(self, rev):
168 def mtnloadmanifest(self, rev):
169 if self.manifest_rev == rev:
169 if self.manifest_rev == rev:
170 return
170 return
171 self.manifest = self.mtnrun("get_manifest_of", rev).split("\n\n")
171 self.manifest = self.mtnrun("get_manifest_of", rev).split("\n\n")
172 self.manifest_rev = rev
172 self.manifest_rev = rev
173 self.files = {}
173 self.files = {}
174 self.dirs = {}
174 self.dirs = {}
175
175
176 for e in self.manifest:
176 for e in self.manifest:
177 m = self.file_re.match(e)
177 m = self.file_re.match(e)
178 if m:
178 if m:
179 attr = ""
179 attr = ""
180 name = m.group(1)
180 name = m.group(1)
181 node = m.group(2)
181 node = m.group(2)
182 if self.attr_execute_re.match(e):
182 if self.attr_execute_re.match(e):
183 attr += "x"
183 attr += "x"
184 self.files[name] = (node, attr)
184 self.files[name] = (node, attr)
185 m = self.dir_re.match(e)
185 m = self.dir_re.match(e)
186 if m:
186 if m:
187 self.dirs[m.group(1)] = True
187 self.dirs[m.group(1)] = True
188
188
189 def mtnisfile(self, name, rev):
189 def mtnisfile(self, name, rev):
190 # a non-file could be a directory or a deleted or renamed file
190 # a non-file could be a directory or a deleted or renamed file
191 self.mtnloadmanifest(rev)
191 self.mtnloadmanifest(rev)
192 return name in self.files
192 return name in self.files
193
193
194 def mtnisdir(self, name, rev):
194 def mtnisdir(self, name, rev):
195 self.mtnloadmanifest(rev)
195 self.mtnloadmanifest(rev)
196 return name in self.dirs
196 return name in self.dirs
197
197
198 def mtngetcerts(self, rev):
198 def mtngetcerts(self, rev):
199 certs = {"author":"<missing>", "date":"<missing>",
199 certs = {"author":"<missing>", "date":"<missing>",
200 "changelog":"<missing>", "branch":"<missing>"}
200 "changelog":"<missing>", "branch":"<missing>"}
201 certlist = self.mtnrun("certs", rev)
201 certlist = self.mtnrun("certs", rev)
202 # mtn < 0.45:
202 # mtn < 0.45:
203 # key "test@selenic.com"
203 # key "test@selenic.com"
204 # mtn >= 0.45:
204 # mtn >= 0.45:
205 # key [ff58a7ffb771907c4ff68995eada1c4da068d328]
205 # key [ff58a7ffb771907c4ff68995eada1c4da068d328]
206 certlist = re.split('\n\n key ["\[]', certlist)
206 certlist = re.split('\n\n key ["\[]', certlist)
207 for e in certlist:
207 for e in certlist:
208 m = self.cert_re.match(e)
208 m = self.cert_re.match(e)
209 if m:
209 if m:
210 name, value = m.groups()
210 name, value = m.groups()
211 value = value.replace(r'\"', '"')
211 value = value.replace(r'\"', '"')
212 value = value.replace(r'\\', '\\')
212 value = value.replace(r'\\', '\\')
213 certs[name] = value
213 certs[name] = value
214 # Monotone may have subsecond dates: 2005-02-05T09:39:12.364306
214 # Monotone may have subsecond dates: 2005-02-05T09:39:12.364306
215 # and all times are stored in UTC
215 # and all times are stored in UTC
216 certs["date"] = certs["date"].split('.')[0] + " UTC"
216 certs["date"] = certs["date"].split('.')[0] + " UTC"
217 return certs
217 return certs
218
218
219 # implement the converter_source interface:
219 # implement the converter_source interface:
220
220
221 def getheads(self):
221 def getheads(self):
222 if not self.rev:
222 if not self.rev:
223 return self.mtnrun("leaves").splitlines()
223 return self.mtnrun("leaves").splitlines()
224 else:
224 else:
225 return [self.rev]
225 return [self.rev]
226
226
227 def getchanges(self, rev):
227 def getchanges(self, rev):
228 #revision = self.mtncmd("get_revision %s" % rev).split("\n\n")
229 revision = self.mtnrun("get_revision", rev).split("\n\n")
228 revision = self.mtnrun("get_revision", rev).split("\n\n")
230 files = {}
229 files = {}
231 ignoremove = {}
230 ignoremove = {}
232 renameddirs = []
231 renameddirs = []
233 copies = {}
232 copies = {}
234 for e in revision:
233 for e in revision:
235 m = self.add_file_re.match(e)
234 m = self.add_file_re.match(e)
236 if m:
235 if m:
237 files[m.group(1)] = rev
236 files[m.group(1)] = rev
238 ignoremove[m.group(1)] = rev
237 ignoremove[m.group(1)] = rev
239 m = self.patch_re.match(e)
238 m = self.patch_re.match(e)
240 if m:
239 if m:
241 files[m.group(1)] = rev
240 files[m.group(1)] = rev
242 # Delete/rename is handled later when the convert engine
241 # Delete/rename is handled later when the convert engine
243 # discovers an IOError exception from getfile,
242 # discovers an IOError exception from getfile,
244 # but only if we add the "from" file to the list of changes.
243 # but only if we add the "from" file to the list of changes.
245 m = self.delete_re.match(e)
244 m = self.delete_re.match(e)
246 if m:
245 if m:
247 files[m.group(1)] = rev
246 files[m.group(1)] = rev
248 m = self.rename_re.match(e)
247 m = self.rename_re.match(e)
249 if m:
248 if m:
250 toname = m.group(2)
249 toname = m.group(2)
251 fromname = m.group(1)
250 fromname = m.group(1)
252 if self.mtnisfile(toname, rev):
251 if self.mtnisfile(toname, rev):
253 ignoremove[toname] = 1
252 ignoremove[toname] = 1
254 copies[toname] = fromname
253 copies[toname] = fromname
255 files[toname] = rev
254 files[toname] = rev
256 files[fromname] = rev
255 files[fromname] = rev
257 elif self.mtnisdir(toname, rev):
256 elif self.mtnisdir(toname, rev):
258 renameddirs.append((fromname, toname))
257 renameddirs.append((fromname, toname))
259
258
260 # Directory renames can be handled only once we have recorded
259 # Directory renames can be handled only once we have recorded
261 # all new files
260 # all new files
262 for fromdir, todir in renameddirs:
261 for fromdir, todir in renameddirs:
263 renamed = {}
262 renamed = {}
264 for tofile in self.files:
263 for tofile in self.files:
265 if tofile in ignoremove:
264 if tofile in ignoremove:
266 continue
265 continue
267 if tofile.startswith(todir + '/'):
266 if tofile.startswith(todir + '/'):
268 renamed[tofile] = fromdir + tofile[len(todir):]
267 renamed[tofile] = fromdir + tofile[len(todir):]
269 # Avoid chained moves like:
268 # Avoid chained moves like:
270 # d1(/a) => d3/d1(/a)
269 # d1(/a) => d3/d1(/a)
271 # d2 => d3
270 # d2 => d3
272 ignoremove[tofile] = 1
271 ignoremove[tofile] = 1
273 for tofile, fromfile in renamed.items():
272 for tofile, fromfile in renamed.items():
274 self.ui.debug (_("copying file in renamed directory "
273 self.ui.debug (_("copying file in renamed directory "
275 "from '%s' to '%s'")
274 "from '%s' to '%s'")
276 % (fromfile, tofile), '\n')
275 % (fromfile, tofile), '\n')
277 files[tofile] = rev
276 files[tofile] = rev
278 copies[tofile] = fromfile
277 copies[tofile] = fromfile
279 for fromfile in renamed.values():
278 for fromfile in renamed.values():
280 files[fromfile] = rev
279 files[fromfile] = rev
281
280
282 return (files.items(), copies)
281 return (files.items(), copies)
283
282
284 def getfile(self, name, rev):
283 def getfile(self, name, rev):
285 if not self.mtnisfile(name, rev):
284 if not self.mtnisfile(name, rev):
286 raise IOError # file was deleted or renamed
285 raise IOError # file was deleted or renamed
287 try:
286 try:
288 data = self.mtnrun("get_file_of", name, r=rev)
287 data = self.mtnrun("get_file_of", name, r=rev)
289 except Exception:
288 except Exception:
290 raise IOError # file was deleted or renamed
289 raise IOError # file was deleted or renamed
291 self.mtnloadmanifest(rev)
290 self.mtnloadmanifest(rev)
292 node, attr = self.files.get(name, (None, ""))
291 node, attr = self.files.get(name, (None, ""))
293 return data, attr
292 return data, attr
294
293
295 def getcommit(self, rev):
294 def getcommit(self, rev):
296 extra = {}
295 extra = {}
297 certs = self.mtngetcerts(rev)
296 certs = self.mtngetcerts(rev)
298 if certs.get('suspend') == certs["branch"]:
297 if certs.get('suspend') == certs["branch"]:
299 extra['close'] = '1'
298 extra['close'] = '1'
300 return commit(
299 return commit(
301 author=certs["author"],
300 author=certs["author"],
302 date=util.datestr(util.strdate(certs["date"], "%Y-%m-%dT%H:%M:%S")),
301 date=util.datestr(util.strdate(certs["date"], "%Y-%m-%dT%H:%M:%S")),
303 desc=certs["changelog"],
302 desc=certs["changelog"],
304 rev=rev,
303 rev=rev,
305 parents=self.mtnrun("parents", rev).splitlines(),
304 parents=self.mtnrun("parents", rev).splitlines(),
306 branch=certs["branch"],
305 branch=certs["branch"],
307 extra=extra)
306 extra=extra)
308
307
309 def gettags(self):
308 def gettags(self):
310 tags = {}
309 tags = {}
311 for e in self.mtnrun("tags").split("\n\n"):
310 for e in self.mtnrun("tags").split("\n\n"):
312 m = self.tag_re.match(e)
311 m = self.tag_re.match(e)
313 if m:
312 if m:
314 tags[m.group(1)] = m.group(2)
313 tags[m.group(1)] = m.group(2)
315 return tags
314 return tags
316
315
317 def getchangedfiles(self, rev, i):
316 def getchangedfiles(self, rev, i):
318 # This function is only needed to support --filemap
317 # This function is only needed to support --filemap
319 # ... and we don't support that
318 # ... and we don't support that
320 raise NotImplementedError
319 raise NotImplementedError
321
320
322 def before(self):
321 def before(self):
323 # Check if we have a new enough version to use automate stdio
322 # Check if we have a new enough version to use automate stdio
324 version = 0.0
323 version = 0.0
325 try:
324 try:
326 versionstr = self.mtnrunsingle("interface_version")
325 versionstr = self.mtnrunsingle("interface_version")
327 version = float(versionstr)
326 version = float(versionstr)
328 except Exception:
327 except Exception:
329 raise util.Abort(_("unable to determine mtn automate interface "
328 raise util.Abort(_("unable to determine mtn automate interface "
330 "version"))
329 "version"))
331
330
332 if version >= 12.0:
331 if version >= 12.0:
333 self.automatestdio = True
332 self.automatestdio = True
334 self.ui.debug("mtn automate version %s - using automate stdio\n" %
333 self.ui.debug("mtn automate version %s - using automate stdio\n" %
335 version)
334 version)
336
335
337 # launch the long-running automate stdio process
336 # launch the long-running automate stdio process
338 self.mtnwritefp, self.mtnreadfp = self._run2('automate', 'stdio',
337 self.mtnwritefp, self.mtnreadfp = self._run2('automate', 'stdio',
339 '-d', self.path)
338 '-d', self.path)
340 # read the headers
339 # read the headers
341 read = self.mtnreadfp.readline()
340 read = self.mtnreadfp.readline()
342 if read != 'format-version: 2\n':
341 if read != 'format-version: 2\n':
343 raise util.Abort(_('mtn automate stdio header unexpected: %s')
342 raise util.Abort(_('mtn automate stdio header unexpected: %s')
344 % read)
343 % read)
345 while read != '\n':
344 while read != '\n':
346 read = self.mtnreadfp.readline()
345 read = self.mtnreadfp.readline()
347 if not read:
346 if not read:
348 raise util.Abort(_("failed to reach end of mtn automate "
347 raise util.Abort(_("failed to reach end of mtn automate "
349 "stdio headers"))
348 "stdio headers"))
350 else:
349 else:
351 self.ui.debug("mtn automate version %s - not using automate stdio "
350 self.ui.debug("mtn automate version %s - not using automate stdio "
352 "(automate >= 12.0 - mtn >= 0.46 is needed)\n" % version)
351 "(automate >= 12.0 - mtn >= 0.46 is needed)\n" % version)
353
352
354 def after(self):
353 def after(self):
355 if self.automatestdio:
354 if self.automatestdio:
356 self.mtnwritefp.close()
355 self.mtnwritefp.close()
357 self.mtnwritefp = None
356 self.mtnwritefp = None
358 self.mtnreadfp.close()
357 self.mtnreadfp.close()
359 self.mtnreadfp = None
358 self.mtnreadfp = None
360
359
@@ -1,1582 +1,1581
1 """ Multicast DNS Service Discovery for Python, v0.12
1 """ Multicast DNS Service Discovery for Python, v0.12
2 Copyright (C) 2003, Paul Scott-Murphy
2 Copyright (C) 2003, Paul Scott-Murphy
3
3
4 This module provides a framework for the use of DNS Service Discovery
4 This module provides a framework for the use of DNS Service Discovery
5 using IP multicast. It has been tested against the JRendezvous
5 using IP multicast. It has been tested against the JRendezvous
6 implementation from <a href="http://strangeberry.com">StrangeBerry</a>,
6 implementation from <a href="http://strangeberry.com">StrangeBerry</a>,
7 and against the mDNSResponder from Mac OS X 10.3.8.
7 and against the mDNSResponder from Mac OS X 10.3.8.
8
8
9 This library is free software; you can redistribute it and/or
9 This library is free software; you can redistribute it and/or
10 modify it under the terms of the GNU Lesser General Public
10 modify it under the terms of the GNU Lesser General Public
11 License as published by the Free Software Foundation; either
11 License as published by the Free Software Foundation; either
12 version 2.1 of the License, or (at your option) any later version.
12 version 2.1 of the License, or (at your option) any later version.
13
13
14 This library is distributed in the hope that it will be useful,
14 This library is distributed in the hope that it will be useful,
15 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 but WITHOUT ANY WARRANTY; without even the implied warranty of
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 Lesser General Public License for more details.
17 Lesser General Public License for more details.
18
18
19 You should have received a copy of the GNU Lesser General Public
19 You should have received a copy of the GNU Lesser General Public
20 License along with this library; if not, see
20 License along with this library; if not, see
21 <http://www.gnu.org/licenses/>.
21 <http://www.gnu.org/licenses/>.
22
22
23 """
23 """
24
24
25 """0.12 update - allow selection of binding interface
25 """0.12 update - allow selection of binding interface
26 typo fix - Thanks A. M. Kuchlingi
26 typo fix - Thanks A. M. Kuchlingi
27 removed all use of word 'Rendezvous' - this is an API change"""
27 removed all use of word 'Rendezvous' - this is an API change"""
28
28
29 """0.11 update - correction to comments for addListener method
29 """0.11 update - correction to comments for addListener method
30 support for new record types seen from OS X
30 support for new record types seen from OS X
31 - IPv6 address
31 - IPv6 address
32 - hostinfo
32 - hostinfo
33 ignore unknown DNS record types
33 ignore unknown DNS record types
34 fixes to name decoding
34 fixes to name decoding
35 works alongside other processes using port 5353 (e.g. on Mac OS X)
35 works alongside other processes using port 5353 (e.g. on Mac OS X)
36 tested against Mac OS X 10.3.2's mDNSResponder
36 tested against Mac OS X 10.3.2's mDNSResponder
37 corrections to removal of list entries for service browser"""
37 corrections to removal of list entries for service browser"""
38
38
39 """0.10 update - Jonathon Paisley contributed these corrections:
39 """0.10 update - Jonathon Paisley contributed these corrections:
40 always multicast replies, even when query is unicast
40 always multicast replies, even when query is unicast
41 correct a pointer encoding problem
41 correct a pointer encoding problem
42 can now write records in any order
42 can now write records in any order
43 traceback shown on failure
43 traceback shown on failure
44 better TXT record parsing
44 better TXT record parsing
45 server is now separate from name
45 server is now separate from name
46 can cancel a service browser
46 can cancel a service browser
47
47
48 modified some unit tests to accommodate these changes"""
48 modified some unit tests to accommodate these changes"""
49
49
50 """0.09 update - remove all records on service unregistration
50 """0.09 update - remove all records on service unregistration
51 fix DOS security problem with readName"""
51 fix DOS security problem with readName"""
52
52
53 """0.08 update - changed licensing to LGPL"""
53 """0.08 update - changed licensing to LGPL"""
54
54
55 """0.07 update - faster shutdown on engine
55 """0.07 update - faster shutdown on engine
56 pointer encoding of outgoing names
56 pointer encoding of outgoing names
57 ServiceBrowser now works
57 ServiceBrowser now works
58 new unit tests"""
58 new unit tests"""
59
59
60 """0.06 update - small improvements with unit tests
60 """0.06 update - small improvements with unit tests
61 added defined exception types
61 added defined exception types
62 new style objects
62 new style objects
63 fixed hostname/interface problem
63 fixed hostname/interface problem
64 fixed socket timeout problem
64 fixed socket timeout problem
65 fixed addServiceListener() typo bug
65 fixed addServiceListener() typo bug
66 using select() for socket reads
66 using select() for socket reads
67 tested on Debian unstable with Python 2.2.2"""
67 tested on Debian unstable with Python 2.2.2"""
68
68
69 """0.05 update - ensure case insensitivity on domain names
69 """0.05 update - ensure case insensitivity on domain names
70 support for unicast DNS queries"""
70 support for unicast DNS queries"""
71
71
72 """0.04 update - added some unit tests
72 """0.04 update - added some unit tests
73 added __ne__ adjuncts where required
73 added __ne__ adjuncts where required
74 ensure names end in '.local.'
74 ensure names end in '.local.'
75 timeout on receiving socket for clean shutdown"""
75 timeout on receiving socket for clean shutdown"""
76
76
77 __author__ = "Paul Scott-Murphy"
77 __author__ = "Paul Scott-Murphy"
78 __email__ = "paul at scott dash murphy dot com"
78 __email__ = "paul at scott dash murphy dot com"
79 __version__ = "0.12"
79 __version__ = "0.12"
80
80
81 import string
81 import string
82 import time
82 import time
83 import struct
83 import struct
84 import socket
84 import socket
85 import threading
85 import threading
86 import select
86 import select
87 import traceback
87 import traceback
88
88
89 __all__ = ["Zeroconf", "ServiceInfo", "ServiceBrowser"]
89 __all__ = ["Zeroconf", "ServiceInfo", "ServiceBrowser"]
90
90
91 # hook for threads
91 # hook for threads
92
92
93 globals()['_GLOBAL_DONE'] = 0
93 globals()['_GLOBAL_DONE'] = 0
94
94
95 # Some timing constants
95 # Some timing constants
96
96
97 _UNREGISTER_TIME = 125
97 _UNREGISTER_TIME = 125
98 _CHECK_TIME = 175
98 _CHECK_TIME = 175
99 _REGISTER_TIME = 225
99 _REGISTER_TIME = 225
100 _LISTENER_TIME = 200
100 _LISTENER_TIME = 200
101 _BROWSER_TIME = 500
101 _BROWSER_TIME = 500
102
102
103 # Some DNS constants
103 # Some DNS constants
104
104
105 _MDNS_ADDR = '224.0.0.251'
105 _MDNS_ADDR = '224.0.0.251'
106 _MDNS_PORT = 5353;
106 _MDNS_PORT = 5353;
107 _DNS_PORT = 53;
107 _DNS_PORT = 53;
108 _DNS_TTL = 60 * 60; # one hour default TTL
108 _DNS_TTL = 60 * 60; # one hour default TTL
109
109
110 _MAX_MSG_TYPICAL = 1460 # unused
110 _MAX_MSG_TYPICAL = 1460 # unused
111 _MAX_MSG_ABSOLUTE = 8972
111 _MAX_MSG_ABSOLUTE = 8972
112
112
113 _FLAGS_QR_MASK = 0x8000 # query response mask
113 _FLAGS_QR_MASK = 0x8000 # query response mask
114 _FLAGS_QR_QUERY = 0x0000 # query
114 _FLAGS_QR_QUERY = 0x0000 # query
115 _FLAGS_QR_RESPONSE = 0x8000 # response
115 _FLAGS_QR_RESPONSE = 0x8000 # response
116
116
117 _FLAGS_AA = 0x0400 # Authoritative answer
117 _FLAGS_AA = 0x0400 # Authoritative answer
118 _FLAGS_TC = 0x0200 # Truncated
118 _FLAGS_TC = 0x0200 # Truncated
119 _FLAGS_RD = 0x0100 # Recursion desired
119 _FLAGS_RD = 0x0100 # Recursion desired
120 _FLAGS_RA = 0x8000 # Recursion available
120 _FLAGS_RA = 0x8000 # Recursion available
121
121
122 _FLAGS_Z = 0x0040 # Zero
122 _FLAGS_Z = 0x0040 # Zero
123 _FLAGS_AD = 0x0020 # Authentic data
123 _FLAGS_AD = 0x0020 # Authentic data
124 _FLAGS_CD = 0x0010 # Checking disabled
124 _FLAGS_CD = 0x0010 # Checking disabled
125
125
126 _CLASS_IN = 1
126 _CLASS_IN = 1
127 _CLASS_CS = 2
127 _CLASS_CS = 2
128 _CLASS_CH = 3
128 _CLASS_CH = 3
129 _CLASS_HS = 4
129 _CLASS_HS = 4
130 _CLASS_NONE = 254
130 _CLASS_NONE = 254
131 _CLASS_ANY = 255
131 _CLASS_ANY = 255
132 _CLASS_MASK = 0x7FFF
132 _CLASS_MASK = 0x7FFF
133 _CLASS_UNIQUE = 0x8000
133 _CLASS_UNIQUE = 0x8000
134
134
135 _TYPE_A = 1
135 _TYPE_A = 1
136 _TYPE_NS = 2
136 _TYPE_NS = 2
137 _TYPE_MD = 3
137 _TYPE_MD = 3
138 _TYPE_MF = 4
138 _TYPE_MF = 4
139 _TYPE_CNAME = 5
139 _TYPE_CNAME = 5
140 _TYPE_SOA = 6
140 _TYPE_SOA = 6
141 _TYPE_MB = 7
141 _TYPE_MB = 7
142 _TYPE_MG = 8
142 _TYPE_MG = 8
143 _TYPE_MR = 9
143 _TYPE_MR = 9
144 _TYPE_NULL = 10
144 _TYPE_NULL = 10
145 _TYPE_WKS = 11
145 _TYPE_WKS = 11
146 _TYPE_PTR = 12
146 _TYPE_PTR = 12
147 _TYPE_HINFO = 13
147 _TYPE_HINFO = 13
148 _TYPE_MINFO = 14
148 _TYPE_MINFO = 14
149 _TYPE_MX = 15
149 _TYPE_MX = 15
150 _TYPE_TXT = 16
150 _TYPE_TXT = 16
151 _TYPE_AAAA = 28
151 _TYPE_AAAA = 28
152 _TYPE_SRV = 33
152 _TYPE_SRV = 33
153 _TYPE_ANY = 255
153 _TYPE_ANY = 255
154
154
155 # Mapping constants to names
155 # Mapping constants to names
156
156
157 _CLASSES = { _CLASS_IN : "in",
157 _CLASSES = { _CLASS_IN : "in",
158 _CLASS_CS : "cs",
158 _CLASS_CS : "cs",
159 _CLASS_CH : "ch",
159 _CLASS_CH : "ch",
160 _CLASS_HS : "hs",
160 _CLASS_HS : "hs",
161 _CLASS_NONE : "none",
161 _CLASS_NONE : "none",
162 _CLASS_ANY : "any" }
162 _CLASS_ANY : "any" }
163
163
164 _TYPES = { _TYPE_A : "a",
164 _TYPES = { _TYPE_A : "a",
165 _TYPE_NS : "ns",
165 _TYPE_NS : "ns",
166 _TYPE_MD : "md",
166 _TYPE_MD : "md",
167 _TYPE_MF : "mf",
167 _TYPE_MF : "mf",
168 _TYPE_CNAME : "cname",
168 _TYPE_CNAME : "cname",
169 _TYPE_SOA : "soa",
169 _TYPE_SOA : "soa",
170 _TYPE_MB : "mb",
170 _TYPE_MB : "mb",
171 _TYPE_MG : "mg",
171 _TYPE_MG : "mg",
172 _TYPE_MR : "mr",
172 _TYPE_MR : "mr",
173 _TYPE_NULL : "null",
173 _TYPE_NULL : "null",
174 _TYPE_WKS : "wks",
174 _TYPE_WKS : "wks",
175 _TYPE_PTR : "ptr",
175 _TYPE_PTR : "ptr",
176 _TYPE_HINFO : "hinfo",
176 _TYPE_HINFO : "hinfo",
177 _TYPE_MINFO : "minfo",
177 _TYPE_MINFO : "minfo",
178 _TYPE_MX : "mx",
178 _TYPE_MX : "mx",
179 _TYPE_TXT : "txt",
179 _TYPE_TXT : "txt",
180 _TYPE_AAAA : "quada",
180 _TYPE_AAAA : "quada",
181 _TYPE_SRV : "srv",
181 _TYPE_SRV : "srv",
182 _TYPE_ANY : "any" }
182 _TYPE_ANY : "any" }
183
183
184 # utility functions
184 # utility functions
185
185
186 def currentTimeMillis():
186 def currentTimeMillis():
187 """Current system time in milliseconds"""
187 """Current system time in milliseconds"""
188 return time.time() * 1000
188 return time.time() * 1000
189
189
190 # Exceptions
190 # Exceptions
191
191
192 class NonLocalNameException(Exception):
192 class NonLocalNameException(Exception):
193 pass
193 pass
194
194
195 class NonUniqueNameException(Exception):
195 class NonUniqueNameException(Exception):
196 pass
196 pass
197
197
198 class NamePartTooLongException(Exception):
198 class NamePartTooLongException(Exception):
199 pass
199 pass
200
200
201 class AbstractMethodException(Exception):
201 class AbstractMethodException(Exception):
202 pass
202 pass
203
203
204 class BadTypeInNameException(Exception):
204 class BadTypeInNameException(Exception):
205 pass
205 pass
206
206
207 class BadDomainName(Exception):
207 class BadDomainName(Exception):
208 def __init__(self, pos):
208 def __init__(self, pos):
209 Exception.__init__(self, "at position %s" % pos)
209 Exception.__init__(self, "at position %s" % pos)
210
210
211 class BadDomainNameCircular(BadDomainName):
211 class BadDomainNameCircular(BadDomainName):
212 pass
212 pass
213
213
214 # implementation classes
214 # implementation classes
215
215
216 class DNSEntry(object):
216 class DNSEntry(object):
217 """A DNS entry"""
217 """A DNS entry"""
218
218
219 def __init__(self, name, type, clazz):
219 def __init__(self, name, type, clazz):
220 self.key = string.lower(name)
220 self.key = string.lower(name)
221 self.name = name
221 self.name = name
222 self.type = type
222 self.type = type
223 self.clazz = clazz & _CLASS_MASK
223 self.clazz = clazz & _CLASS_MASK
224 self.unique = (clazz & _CLASS_UNIQUE) != 0
224 self.unique = (clazz & _CLASS_UNIQUE) != 0
225
225
226 def __eq__(self, other):
226 def __eq__(self, other):
227 """Equality test on name, type, and class"""
227 """Equality test on name, type, and class"""
228 if isinstance(other, DNSEntry):
228 if isinstance(other, DNSEntry):
229 return self.name == other.name and self.type == other.type and self.clazz == other.clazz
229 return self.name == other.name and self.type == other.type and self.clazz == other.clazz
230 return 0
230 return 0
231
231
232 def __ne__(self, other):
232 def __ne__(self, other):
233 """Non-equality test"""
233 """Non-equality test"""
234 return not self.__eq__(other)
234 return not self.__eq__(other)
235
235
236 def getClazz(self, clazz):
236 def getClazz(self, clazz):
237 """Class accessor"""
237 """Class accessor"""
238 try:
238 try:
239 return _CLASSES[clazz]
239 return _CLASSES[clazz]
240 except KeyError:
240 except KeyError:
241 return "?(%s)" % (clazz)
241 return "?(%s)" % (clazz)
242
242
243 def getType(self, type):
243 def getType(self, type):
244 """Type accessor"""
244 """Type accessor"""
245 try:
245 try:
246 return _TYPES[type]
246 return _TYPES[type]
247 except KeyError:
247 except KeyError:
248 return "?(%s)" % (type)
248 return "?(%s)" % (type)
249
249
250 def toString(self, hdr, other):
250 def toString(self, hdr, other):
251 """String representation with additional information"""
251 """String representation with additional information"""
252 result = "%s[%s,%s" % (hdr, self.getType(self.type), self.getClazz(self.clazz))
252 result = "%s[%s,%s" % (hdr, self.getType(self.type), self.getClazz(self.clazz))
253 if self.unique:
253 if self.unique:
254 result += "-unique,"
254 result += "-unique,"
255 else:
255 else:
256 result += ","
256 result += ","
257 result += self.name
257 result += self.name
258 if other is not None:
258 if other is not None:
259 result += ",%s]" % (other)
259 result += ",%s]" % (other)
260 else:
260 else:
261 result += "]"
261 result += "]"
262 return result
262 return result
263
263
264 class DNSQuestion(DNSEntry):
264 class DNSQuestion(DNSEntry):
265 """A DNS question entry"""
265 """A DNS question entry"""
266
266
267 def __init__(self, name, type, clazz):
267 def __init__(self, name, type, clazz):
268 if not name.endswith(".local."):
268 if not name.endswith(".local."):
269 raise NonLocalNameException(name)
269 raise NonLocalNameException(name)
270 DNSEntry.__init__(self, name, type, clazz)
270 DNSEntry.__init__(self, name, type, clazz)
271
271
272 def answeredBy(self, rec):
272 def answeredBy(self, rec):
273 """Returns true if the question is answered by the record"""
273 """Returns true if the question is answered by the record"""
274 return self.clazz == rec.clazz and (self.type == rec.type or self.type == _TYPE_ANY) and self.name == rec.name
274 return self.clazz == rec.clazz and (self.type == rec.type or self.type == _TYPE_ANY) and self.name == rec.name
275
275
276 def __repr__(self):
276 def __repr__(self):
277 """String representation"""
277 """String representation"""
278 return DNSEntry.toString(self, "question", None)
278 return DNSEntry.toString(self, "question", None)
279
279
280
280
281 class DNSRecord(DNSEntry):
281 class DNSRecord(DNSEntry):
282 """A DNS record - like a DNS entry, but has a TTL"""
282 """A DNS record - like a DNS entry, but has a TTL"""
283
283
284 def __init__(self, name, type, clazz, ttl):
284 def __init__(self, name, type, clazz, ttl):
285 DNSEntry.__init__(self, name, type, clazz)
285 DNSEntry.__init__(self, name, type, clazz)
286 self.ttl = ttl
286 self.ttl = ttl
287 self.created = currentTimeMillis()
287 self.created = currentTimeMillis()
288
288
289 def __eq__(self, other):
289 def __eq__(self, other):
290 """Tests equality as per DNSRecord"""
290 """Tests equality as per DNSRecord"""
291 if isinstance(other, DNSRecord):
291 if isinstance(other, DNSRecord):
292 return DNSEntry.__eq__(self, other)
292 return DNSEntry.__eq__(self, other)
293 return 0
293 return 0
294
294
295 def suppressedBy(self, msg):
295 def suppressedBy(self, msg):
296 """Returns true if any answer in a message can suffice for the
296 """Returns true if any answer in a message can suffice for the
297 information held in this record."""
297 information held in this record."""
298 for record in msg.answers:
298 for record in msg.answers:
299 if self.suppressedByAnswer(record):
299 if self.suppressedByAnswer(record):
300 return 1
300 return 1
301 return 0
301 return 0
302
302
303 def suppressedByAnswer(self, other):
303 def suppressedByAnswer(self, other):
304 """Returns true if another record has same name, type and class,
304 """Returns true if another record has same name, type and class,
305 and if its TTL is at least half of this record's."""
305 and if its TTL is at least half of this record's."""
306 if self == other and other.ttl > (self.ttl / 2):
306 if self == other and other.ttl > (self.ttl / 2):
307 return 1
307 return 1
308 return 0
308 return 0
309
309
310 def getExpirationTime(self, percent):
310 def getExpirationTime(self, percent):
311 """Returns the time at which this record will have expired
311 """Returns the time at which this record will have expired
312 by a certain percentage."""
312 by a certain percentage."""
313 return self.created + (percent * self.ttl * 10)
313 return self.created + (percent * self.ttl * 10)
314
314
315 def getRemainingTTL(self, now):
315 def getRemainingTTL(self, now):
316 """Returns the remaining TTL in seconds."""
316 """Returns the remaining TTL in seconds."""
317 return max(0, (self.getExpirationTime(100) - now) / 1000)
317 return max(0, (self.getExpirationTime(100) - now) / 1000)
318
318
319 def isExpired(self, now):
319 def isExpired(self, now):
320 """Returns true if this record has expired."""
320 """Returns true if this record has expired."""
321 return self.getExpirationTime(100) <= now
321 return self.getExpirationTime(100) <= now
322
322
323 def isStale(self, now):
323 def isStale(self, now):
324 """Returns true if this record is at least half way expired."""
324 """Returns true if this record is at least half way expired."""
325 return self.getExpirationTime(50) <= now
325 return self.getExpirationTime(50) <= now
326
326
327 def resetTTL(self, other):
327 def resetTTL(self, other):
328 """Sets this record's TTL and created time to that of
328 """Sets this record's TTL and created time to that of
329 another record."""
329 another record."""
330 self.created = other.created
330 self.created = other.created
331 self.ttl = other.ttl
331 self.ttl = other.ttl
332
332
333 def write(self, out):
333 def write(self, out):
334 """Abstract method"""
334 """Abstract method"""
335 raise AbstractMethodException
335 raise AbstractMethodException
336
336
337 def toString(self, other):
337 def toString(self, other):
338 """String representation with additional information"""
338 """String representation with additional information"""
339 arg = "%s/%s,%s" % (self.ttl, self.getRemainingTTL(currentTimeMillis()), other)
339 arg = "%s/%s,%s" % (self.ttl, self.getRemainingTTL(currentTimeMillis()), other)
340 return DNSEntry.toString(self, "record", arg)
340 return DNSEntry.toString(self, "record", arg)
341
341
342 class DNSAddress(DNSRecord):
342 class DNSAddress(DNSRecord):
343 """A DNS address record"""
343 """A DNS address record"""
344
344
345 def __init__(self, name, type, clazz, ttl, address):
345 def __init__(self, name, type, clazz, ttl, address):
346 DNSRecord.__init__(self, name, type, clazz, ttl)
346 DNSRecord.__init__(self, name, type, clazz, ttl)
347 self.address = address
347 self.address = address
348
348
349 def write(self, out):
349 def write(self, out):
350 """Used in constructing an outgoing packet"""
350 """Used in constructing an outgoing packet"""
351 out.writeString(self.address, len(self.address))
351 out.writeString(self.address, len(self.address))
352
352
353 def __eq__(self, other):
353 def __eq__(self, other):
354 """Tests equality on address"""
354 """Tests equality on address"""
355 if isinstance(other, DNSAddress):
355 if isinstance(other, DNSAddress):
356 return self.address == other.address
356 return self.address == other.address
357 return 0
357 return 0
358
358
359 def __repr__(self):
359 def __repr__(self):
360 """String representation"""
360 """String representation"""
361 try:
361 try:
362 return socket.inet_ntoa(self.address)
362 return socket.inet_ntoa(self.address)
363 except Exception:
363 except Exception:
364 return self.address
364 return self.address
365
365
366 class DNSHinfo(DNSRecord):
366 class DNSHinfo(DNSRecord):
367 """A DNS host information record"""
367 """A DNS host information record"""
368
368
369 def __init__(self, name, type, clazz, ttl, cpu, os):
369 def __init__(self, name, type, clazz, ttl, cpu, os):
370 DNSRecord.__init__(self, name, type, clazz, ttl)
370 DNSRecord.__init__(self, name, type, clazz, ttl)
371 self.cpu = cpu
371 self.cpu = cpu
372 self.os = os
372 self.os = os
373
373
374 def write(self, out):
374 def write(self, out):
375 """Used in constructing an outgoing packet"""
375 """Used in constructing an outgoing packet"""
376 out.writeString(self.cpu, len(self.cpu))
376 out.writeString(self.cpu, len(self.cpu))
377 out.writeString(self.os, len(self.os))
377 out.writeString(self.os, len(self.os))
378
378
379 def __eq__(self, other):
379 def __eq__(self, other):
380 """Tests equality on cpu and os"""
380 """Tests equality on cpu and os"""
381 if isinstance(other, DNSHinfo):
381 if isinstance(other, DNSHinfo):
382 return self.cpu == other.cpu and self.os == other.os
382 return self.cpu == other.cpu and self.os == other.os
383 return 0
383 return 0
384
384
385 def __repr__(self):
385 def __repr__(self):
386 """String representation"""
386 """String representation"""
387 return self.cpu + " " + self.os
387 return self.cpu + " " + self.os
388
388
389 class DNSPointer(DNSRecord):
389 class DNSPointer(DNSRecord):
390 """A DNS pointer record"""
390 """A DNS pointer record"""
391
391
392 def __init__(self, name, type, clazz, ttl, alias):
392 def __init__(self, name, type, clazz, ttl, alias):
393 DNSRecord.__init__(self, name, type, clazz, ttl)
393 DNSRecord.__init__(self, name, type, clazz, ttl)
394 self.alias = alias
394 self.alias = alias
395
395
396 def write(self, out):
396 def write(self, out):
397 """Used in constructing an outgoing packet"""
397 """Used in constructing an outgoing packet"""
398 out.writeName(self.alias)
398 out.writeName(self.alias)
399
399
400 def __eq__(self, other):
400 def __eq__(self, other):
401 """Tests equality on alias"""
401 """Tests equality on alias"""
402 if isinstance(other, DNSPointer):
402 if isinstance(other, DNSPointer):
403 return self.alias == other.alias
403 return self.alias == other.alias
404 return 0
404 return 0
405
405
406 def __repr__(self):
406 def __repr__(self):
407 """String representation"""
407 """String representation"""
408 return self.toString(self.alias)
408 return self.toString(self.alias)
409
409
410 class DNSText(DNSRecord):
410 class DNSText(DNSRecord):
411 """A DNS text record"""
411 """A DNS text record"""
412
412
413 def __init__(self, name, type, clazz, ttl, text):
413 def __init__(self, name, type, clazz, ttl, text):
414 DNSRecord.__init__(self, name, type, clazz, ttl)
414 DNSRecord.__init__(self, name, type, clazz, ttl)
415 self.text = text
415 self.text = text
416
416
417 def write(self, out):
417 def write(self, out):
418 """Used in constructing an outgoing packet"""
418 """Used in constructing an outgoing packet"""
419 out.writeString(self.text, len(self.text))
419 out.writeString(self.text, len(self.text))
420
420
421 def __eq__(self, other):
421 def __eq__(self, other):
422 """Tests equality on text"""
422 """Tests equality on text"""
423 if isinstance(other, DNSText):
423 if isinstance(other, DNSText):
424 return self.text == other.text
424 return self.text == other.text
425 return 0
425 return 0
426
426
427 def __repr__(self):
427 def __repr__(self):
428 """String representation"""
428 """String representation"""
429 if len(self.text) > 10:
429 if len(self.text) > 10:
430 return self.toString(self.text[:7] + "...")
430 return self.toString(self.text[:7] + "...")
431 else:
431 else:
432 return self.toString(self.text)
432 return self.toString(self.text)
433
433
434 class DNSService(DNSRecord):
434 class DNSService(DNSRecord):
435 """A DNS service record"""
435 """A DNS service record"""
436
436
437 def __init__(self, name, type, clazz, ttl, priority, weight, port, server):
437 def __init__(self, name, type, clazz, ttl, priority, weight, port, server):
438 DNSRecord.__init__(self, name, type, clazz, ttl)
438 DNSRecord.__init__(self, name, type, clazz, ttl)
439 self.priority = priority
439 self.priority = priority
440 self.weight = weight
440 self.weight = weight
441 self.port = port
441 self.port = port
442 self.server = server
442 self.server = server
443
443
444 def write(self, out):
444 def write(self, out):
445 """Used in constructing an outgoing packet"""
445 """Used in constructing an outgoing packet"""
446 out.writeShort(self.priority)
446 out.writeShort(self.priority)
447 out.writeShort(self.weight)
447 out.writeShort(self.weight)
448 out.writeShort(self.port)
448 out.writeShort(self.port)
449 out.writeName(self.server)
449 out.writeName(self.server)
450
450
451 def __eq__(self, other):
451 def __eq__(self, other):
452 """Tests equality on priority, weight, port and server"""
452 """Tests equality on priority, weight, port and server"""
453 if isinstance(other, DNSService):
453 if isinstance(other, DNSService):
454 return self.priority == other.priority and self.weight == other.weight and self.port == other.port and self.server == other.server
454 return self.priority == other.priority and self.weight == other.weight and self.port == other.port and self.server == other.server
455 return 0
455 return 0
456
456
457 def __repr__(self):
457 def __repr__(self):
458 """String representation"""
458 """String representation"""
459 return self.toString("%s:%s" % (self.server, self.port))
459 return self.toString("%s:%s" % (self.server, self.port))
460
460
461 class DNSIncoming(object):
461 class DNSIncoming(object):
462 """Object representation of an incoming DNS packet"""
462 """Object representation of an incoming DNS packet"""
463
463
464 def __init__(self, data):
464 def __init__(self, data):
465 """Constructor from string holding bytes of packet"""
465 """Constructor from string holding bytes of packet"""
466 self.offset = 0
466 self.offset = 0
467 self.data = data
467 self.data = data
468 self.questions = []
468 self.questions = []
469 self.answers = []
469 self.answers = []
470 self.numQuestions = 0
470 self.numQuestions = 0
471 self.numAnswers = 0
471 self.numAnswers = 0
472 self.numAuthorities = 0
472 self.numAuthorities = 0
473 self.numAdditionals = 0
473 self.numAdditionals = 0
474
474
475 self.readHeader()
475 self.readHeader()
476 self.readQuestions()
476 self.readQuestions()
477 self.readOthers()
477 self.readOthers()
478
478
479 def readHeader(self):
479 def readHeader(self):
480 """Reads header portion of packet"""
480 """Reads header portion of packet"""
481 format = '!HHHHHH'
481 format = '!HHHHHH'
482 length = struct.calcsize(format)
482 length = struct.calcsize(format)
483 info = struct.unpack(format, self.data[self.offset:self.offset+length])
483 info = struct.unpack(format, self.data[self.offset:self.offset+length])
484 self.offset += length
484 self.offset += length
485
485
486 self.id = info[0]
486 self.id = info[0]
487 self.flags = info[1]
487 self.flags = info[1]
488 self.numQuestions = info[2]
488 self.numQuestions = info[2]
489 self.numAnswers = info[3]
489 self.numAnswers = info[3]
490 self.numAuthorities = info[4]
490 self.numAuthorities = info[4]
491 self.numAdditionals = info[5]
491 self.numAdditionals = info[5]
492
492
493 def readQuestions(self):
493 def readQuestions(self):
494 """Reads questions section of packet"""
494 """Reads questions section of packet"""
495 format = '!HH'
495 format = '!HH'
496 length = struct.calcsize(format)
496 length = struct.calcsize(format)
497 for i in range(0, self.numQuestions):
497 for i in range(0, self.numQuestions):
498 name = self.readName()
498 name = self.readName()
499 info = struct.unpack(format, self.data[self.offset:self.offset+length])
499 info = struct.unpack(format, self.data[self.offset:self.offset+length])
500 self.offset += length
500 self.offset += length
501
501
502 try:
502 try:
503 question = DNSQuestion(name, info[0], info[1])
503 question = DNSQuestion(name, info[0], info[1])
504 self.questions.append(question)
504 self.questions.append(question)
505 except NonLocalNameException:
505 except NonLocalNameException:
506 pass
506 pass
507
507
508 def readInt(self):
508 def readInt(self):
509 """Reads an integer from the packet"""
509 """Reads an integer from the packet"""
510 format = '!I'
510 format = '!I'
511 length = struct.calcsize(format)
511 length = struct.calcsize(format)
512 info = struct.unpack(format, self.data[self.offset:self.offset+length])
512 info = struct.unpack(format, self.data[self.offset:self.offset+length])
513 self.offset += length
513 self.offset += length
514 return info[0]
514 return info[0]
515
515
516 def readCharacterString(self):
516 def readCharacterString(self):
517 """Reads a character string from the packet"""
517 """Reads a character string from the packet"""
518 length = ord(self.data[self.offset])
518 length = ord(self.data[self.offset])
519 self.offset += 1
519 self.offset += 1
520 return self.readString(length)
520 return self.readString(length)
521
521
522 def readString(self, len):
522 def readString(self, len):
523 """Reads a string of a given length from the packet"""
523 """Reads a string of a given length from the packet"""
524 format = '!' + str(len) + 's'
524 format = '!' + str(len) + 's'
525 length = struct.calcsize(format)
525 length = struct.calcsize(format)
526 info = struct.unpack(format, self.data[self.offset:self.offset+length])
526 info = struct.unpack(format, self.data[self.offset:self.offset+length])
527 self.offset += length
527 self.offset += length
528 return info[0]
528 return info[0]
529
529
530 def readUnsignedShort(self):
530 def readUnsignedShort(self):
531 """Reads an unsigned short from the packet"""
531 """Reads an unsigned short from the packet"""
532 format = '!H'
532 format = '!H'
533 length = struct.calcsize(format)
533 length = struct.calcsize(format)
534 info = struct.unpack(format, self.data[self.offset:self.offset+length])
534 info = struct.unpack(format, self.data[self.offset:self.offset+length])
535 self.offset += length
535 self.offset += length
536 return info[0]
536 return info[0]
537
537
538 def readOthers(self):
538 def readOthers(self):
539 """Reads the answers, authorities and additionals section of the packet"""
539 """Reads the answers, authorities and additionals section of the packet"""
540 format = '!HHiH'
540 format = '!HHiH'
541 length = struct.calcsize(format)
541 length = struct.calcsize(format)
542 n = self.numAnswers + self.numAuthorities + self.numAdditionals
542 n = self.numAnswers + self.numAuthorities + self.numAdditionals
543 for i in range(0, n):
543 for i in range(0, n):
544 domain = self.readName()
544 domain = self.readName()
545 info = struct.unpack(format, self.data[self.offset:self.offset+length])
545 info = struct.unpack(format, self.data[self.offset:self.offset+length])
546 self.offset += length
546 self.offset += length
547
547
548 rec = None
548 rec = None
549 if info[0] == _TYPE_A:
549 if info[0] == _TYPE_A:
550 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(4))
550 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(4))
551 elif info[0] == _TYPE_CNAME or info[0] == _TYPE_PTR:
551 elif info[0] == _TYPE_CNAME or info[0] == _TYPE_PTR:
552 rec = DNSPointer(domain, info[0], info[1], info[2], self.readName())
552 rec = DNSPointer(domain, info[0], info[1], info[2], self.readName())
553 elif info[0] == _TYPE_TXT:
553 elif info[0] == _TYPE_TXT:
554 rec = DNSText(domain, info[0], info[1], info[2], self.readString(info[3]))
554 rec = DNSText(domain, info[0], info[1], info[2], self.readString(info[3]))
555 elif info[0] == _TYPE_SRV:
555 elif info[0] == _TYPE_SRV:
556 rec = DNSService(domain, info[0], info[1], info[2], self.readUnsignedShort(), self.readUnsignedShort(), self.readUnsignedShort(), self.readName())
556 rec = DNSService(domain, info[0], info[1], info[2], self.readUnsignedShort(), self.readUnsignedShort(), self.readUnsignedShort(), self.readName())
557 elif info[0] == _TYPE_HINFO:
557 elif info[0] == _TYPE_HINFO:
558 rec = DNSHinfo(domain, info[0], info[1], info[2], self.readCharacterString(), self.readCharacterString())
558 rec = DNSHinfo(domain, info[0], info[1], info[2], self.readCharacterString(), self.readCharacterString())
559 elif info[0] == _TYPE_AAAA:
559 elif info[0] == _TYPE_AAAA:
560 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(16))
560 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(16))
561 else:
561 else:
562 # Try to ignore types we don't know about
562 # Try to ignore types we don't know about
563 # this may mean the rest of the name is
563 # this may mean the rest of the name is
564 # unable to be parsed, and may show errors
564 # unable to be parsed, and may show errors
565 # so this is left for debugging. New types
565 # so this is left for debugging. New types
566 # encountered need to be parsed properly.
566 # encountered need to be parsed properly.
567 #
567 #
568 #print "UNKNOWN TYPE = " + str(info[0])
568 #print "UNKNOWN TYPE = " + str(info[0])
569 #raise BadTypeInNameException
569 #raise BadTypeInNameException
570 self.offset += info[3]
570 self.offset += info[3]
571
571
572 if rec is not None:
572 if rec is not None:
573 self.answers.append(rec)
573 self.answers.append(rec)
574
574
575 def isQuery(self):
575 def isQuery(self):
576 """Returns true if this is a query"""
576 """Returns true if this is a query"""
577 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
577 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
578
578
579 def isResponse(self):
579 def isResponse(self):
580 """Returns true if this is a response"""
580 """Returns true if this is a response"""
581 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
581 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
582
582
583 def readUTF(self, offset, len):
583 def readUTF(self, offset, len):
584 """Reads a UTF-8 string of a given length from the packet"""
584 """Reads a UTF-8 string of a given length from the packet"""
585 return self.data[offset:offset+len].decode('utf-8')
585 return self.data[offset:offset+len].decode('utf-8')
586
586
587 def readName(self):
587 def readName(self):
588 """Reads a domain name from the packet"""
588 """Reads a domain name from the packet"""
589 result = ''
589 result = ''
590 off = self.offset
590 off = self.offset
591 next = -1
591 next = -1
592 first = off
592 first = off
593
593
594 while True:
594 while True:
595 len = ord(self.data[off])
595 len = ord(self.data[off])
596 off += 1
596 off += 1
597 if len == 0:
597 if len == 0:
598 break
598 break
599 t = len & 0xC0
599 t = len & 0xC0
600 if t == 0x00:
600 if t == 0x00:
601 result = ''.join((result, self.readUTF(off, len) + '.'))
601 result = ''.join((result, self.readUTF(off, len) + '.'))
602 off += len
602 off += len
603 elif t == 0xC0:
603 elif t == 0xC0:
604 if next < 0:
604 if next < 0:
605 next = off + 1
605 next = off + 1
606 off = ((len & 0x3F) << 8) | ord(self.data[off])
606 off = ((len & 0x3F) << 8) | ord(self.data[off])
607 if off >= first:
607 if off >= first:
608 raise BadDomainNameCircular(off)
608 raise BadDomainNameCircular(off)
609 first = off
609 first = off
610 else:
610 else:
611 raise BadDomainName(off)
611 raise BadDomainName(off)
612
612
613 if next >= 0:
613 if next >= 0:
614 self.offset = next
614 self.offset = next
615 else:
615 else:
616 self.offset = off
616 self.offset = off
617
617
618 return result
618 return result
619
619
620
620
621 class DNSOutgoing(object):
621 class DNSOutgoing(object):
622 """Object representation of an outgoing packet"""
622 """Object representation of an outgoing packet"""
623
623
624 def __init__(self, flags, multicast = 1):
624 def __init__(self, flags, multicast = 1):
625 self.finished = 0
625 self.finished = 0
626 self.id = 0
626 self.id = 0
627 self.multicast = multicast
627 self.multicast = multicast
628 self.flags = flags
628 self.flags = flags
629 self.names = {}
629 self.names = {}
630 self.data = []
630 self.data = []
631 self.size = 12
631 self.size = 12
632
632
633 self.questions = []
633 self.questions = []
634 self.answers = []
634 self.answers = []
635 self.authorities = []
635 self.authorities = []
636 self.additionals = []
636 self.additionals = []
637
637
638 def addQuestion(self, record):
638 def addQuestion(self, record):
639 """Adds a question"""
639 """Adds a question"""
640 self.questions.append(record)
640 self.questions.append(record)
641
641
642 def addAnswer(self, inp, record):
642 def addAnswer(self, inp, record):
643 """Adds an answer"""
643 """Adds an answer"""
644 if not record.suppressedBy(inp):
644 if not record.suppressedBy(inp):
645 self.addAnswerAtTime(record, 0)
645 self.addAnswerAtTime(record, 0)
646
646
647 def addAnswerAtTime(self, record, now):
647 def addAnswerAtTime(self, record, now):
648 """Adds an answer if if does not expire by a certain time"""
648 """Adds an answer if if does not expire by a certain time"""
649 if record is not None:
649 if record is not None:
650 if now == 0 or not record.isExpired(now):
650 if now == 0 or not record.isExpired(now):
651 self.answers.append((record, now))
651 self.answers.append((record, now))
652
652
653 def addAuthoritativeAnswer(self, record):
653 def addAuthoritativeAnswer(self, record):
654 """Adds an authoritative answer"""
654 """Adds an authoritative answer"""
655 self.authorities.append(record)
655 self.authorities.append(record)
656
656
657 def addAdditionalAnswer(self, record):
657 def addAdditionalAnswer(self, record):
658 """Adds an additional answer"""
658 """Adds an additional answer"""
659 self.additionals.append(record)
659 self.additionals.append(record)
660
660
661 def writeByte(self, value):
661 def writeByte(self, value):
662 """Writes a single byte to the packet"""
662 """Writes a single byte to the packet"""
663 format = '!c'
663 format = '!c'
664 self.data.append(struct.pack(format, chr(value)))
664 self.data.append(struct.pack(format, chr(value)))
665 self.size += 1
665 self.size += 1
666
666
667 def insertShort(self, index, value):
667 def insertShort(self, index, value):
668 """Inserts an unsigned short in a certain position in the packet"""
668 """Inserts an unsigned short in a certain position in the packet"""
669 format = '!H'
669 format = '!H'
670 self.data.insert(index, struct.pack(format, value))
670 self.data.insert(index, struct.pack(format, value))
671 self.size += 2
671 self.size += 2
672
672
673 def writeShort(self, value):
673 def writeShort(self, value):
674 """Writes an unsigned short to the packet"""
674 """Writes an unsigned short to the packet"""
675 format = '!H'
675 format = '!H'
676 self.data.append(struct.pack(format, value))
676 self.data.append(struct.pack(format, value))
677 self.size += 2
677 self.size += 2
678
678
679 def writeInt(self, value):
679 def writeInt(self, value):
680 """Writes an unsigned integer to the packet"""
680 """Writes an unsigned integer to the packet"""
681 format = '!I'
681 format = '!I'
682 self.data.append(struct.pack(format, int(value)))
682 self.data.append(struct.pack(format, int(value)))
683 self.size += 4
683 self.size += 4
684
684
685 def writeString(self, value, length):
685 def writeString(self, value, length):
686 """Writes a string to the packet"""
686 """Writes a string to the packet"""
687 format = '!' + str(length) + 's'
687 format = '!' + str(length) + 's'
688 self.data.append(struct.pack(format, value))
688 self.data.append(struct.pack(format, value))
689 self.size += length
689 self.size += length
690
690
691 def writeUTF(self, s):
691 def writeUTF(self, s):
692 """Writes a UTF-8 string of a given length to the packet"""
692 """Writes a UTF-8 string of a given length to the packet"""
693 utfstr = s.encode('utf-8')
693 utfstr = s.encode('utf-8')
694 length = len(utfstr)
694 length = len(utfstr)
695 if length > 64:
695 if length > 64:
696 raise NamePartTooLongException
696 raise NamePartTooLongException
697 self.writeByte(length)
697 self.writeByte(length)
698 self.writeString(utfstr, length)
698 self.writeString(utfstr, length)
699
699
700 def writeName(self, name):
700 def writeName(self, name):
701 """Writes a domain name to the packet"""
701 """Writes a domain name to the packet"""
702
702
703 try:
703 try:
704 # Find existing instance of this name in packet
704 # Find existing instance of this name in packet
705 #
705 #
706 index = self.names[name]
706 index = self.names[name]
707 except KeyError:
707 except KeyError:
708 # No record of this name already, so write it
708 # No record of this name already, so write it
709 # out as normal, recording the location of the name
709 # out as normal, recording the location of the name
710 # for future pointers to it.
710 # for future pointers to it.
711 #
711 #
712 self.names[name] = self.size
712 self.names[name] = self.size
713 parts = name.split('.')
713 parts = name.split('.')
714 if parts[-1] == '':
714 if parts[-1] == '':
715 parts = parts[:-1]
715 parts = parts[:-1]
716 for part in parts:
716 for part in parts:
717 self.writeUTF(part)
717 self.writeUTF(part)
718 self.writeByte(0)
718 self.writeByte(0)
719 return
719 return
720
720
721 # An index was found, so write a pointer to it
721 # An index was found, so write a pointer to it
722 #
722 #
723 self.writeByte((index >> 8) | 0xC0)
723 self.writeByte((index >> 8) | 0xC0)
724 self.writeByte(index)
724 self.writeByte(index)
725
725
726 def writeQuestion(self, question):
726 def writeQuestion(self, question):
727 """Writes a question to the packet"""
727 """Writes a question to the packet"""
728 self.writeName(question.name)
728 self.writeName(question.name)
729 self.writeShort(question.type)
729 self.writeShort(question.type)
730 self.writeShort(question.clazz)
730 self.writeShort(question.clazz)
731
731
732 def writeRecord(self, record, now):
732 def writeRecord(self, record, now):
733 """Writes a record (answer, authoritative answer, additional) to
733 """Writes a record (answer, authoritative answer, additional) to
734 the packet"""
734 the packet"""
735 self.writeName(record.name)
735 self.writeName(record.name)
736 self.writeShort(record.type)
736 self.writeShort(record.type)
737 if record.unique and self.multicast:
737 if record.unique and self.multicast:
738 self.writeShort(record.clazz | _CLASS_UNIQUE)
738 self.writeShort(record.clazz | _CLASS_UNIQUE)
739 else:
739 else:
740 self.writeShort(record.clazz)
740 self.writeShort(record.clazz)
741 if now == 0:
741 if now == 0:
742 self.writeInt(record.ttl)
742 self.writeInt(record.ttl)
743 else:
743 else:
744 self.writeInt(record.getRemainingTTL(now))
744 self.writeInt(record.getRemainingTTL(now))
745 index = len(self.data)
745 index = len(self.data)
746 # Adjust size for the short we will write before this record
746 # Adjust size for the short we will write before this record
747 #
747 #
748 self.size += 2
748 self.size += 2
749 record.write(self)
749 record.write(self)
750 self.size -= 2
750 self.size -= 2
751
751
752 length = len(''.join(self.data[index:]))
752 length = len(''.join(self.data[index:]))
753 self.insertShort(index, length) # Here is the short we adjusted for
753 self.insertShort(index, length) # Here is the short we adjusted for
754
754
755 def packet(self):
755 def packet(self):
756 """Returns a string containing the packet's bytes
756 """Returns a string containing the packet's bytes
757
757
758 No further parts should be added to the packet once this
758 No further parts should be added to the packet once this
759 is done."""
759 is done."""
760 if not self.finished:
760 if not self.finished:
761 self.finished = 1
761 self.finished = 1
762 for question in self.questions:
762 for question in self.questions:
763 self.writeQuestion(question)
763 self.writeQuestion(question)
764 for answer, time in self.answers:
764 for answer, time in self.answers:
765 self.writeRecord(answer, time)
765 self.writeRecord(answer, time)
766 for authority in self.authorities:
766 for authority in self.authorities:
767 self.writeRecord(authority, 0)
767 self.writeRecord(authority, 0)
768 for additional in self.additionals:
768 for additional in self.additionals:
769 self.writeRecord(additional, 0)
769 self.writeRecord(additional, 0)
770
770
771 self.insertShort(0, len(self.additionals))
771 self.insertShort(0, len(self.additionals))
772 self.insertShort(0, len(self.authorities))
772 self.insertShort(0, len(self.authorities))
773 self.insertShort(0, len(self.answers))
773 self.insertShort(0, len(self.answers))
774 self.insertShort(0, len(self.questions))
774 self.insertShort(0, len(self.questions))
775 self.insertShort(0, self.flags)
775 self.insertShort(0, self.flags)
776 if self.multicast:
776 if self.multicast:
777 self.insertShort(0, 0)
777 self.insertShort(0, 0)
778 else:
778 else:
779 self.insertShort(0, self.id)
779 self.insertShort(0, self.id)
780 return ''.join(self.data)
780 return ''.join(self.data)
781
781
782
782
783 class DNSCache(object):
783 class DNSCache(object):
784 """A cache of DNS entries"""
784 """A cache of DNS entries"""
785
785
786 def __init__(self):
786 def __init__(self):
787 self.cache = {}
787 self.cache = {}
788
788
789 def add(self, entry):
789 def add(self, entry):
790 """Adds an entry"""
790 """Adds an entry"""
791 try:
791 try:
792 list = self.cache[entry.key]
792 list = self.cache[entry.key]
793 except KeyError:
793 except KeyError:
794 list = self.cache[entry.key] = []
794 list = self.cache[entry.key] = []
795 list.append(entry)
795 list.append(entry)
796
796
797 def remove(self, entry):
797 def remove(self, entry):
798 """Removes an entry"""
798 """Removes an entry"""
799 try:
799 try:
800 list = self.cache[entry.key]
800 list = self.cache[entry.key]
801 list.remove(entry)
801 list.remove(entry)
802 except KeyError:
802 except KeyError:
803 pass
803 pass
804
804
805 def get(self, entry):
805 def get(self, entry):
806 """Gets an entry by key. Will return None if there is no
806 """Gets an entry by key. Will return None if there is no
807 matching entry."""
807 matching entry."""
808 try:
808 try:
809 list = self.cache[entry.key]
809 list = self.cache[entry.key]
810 return list[list.index(entry)]
810 return list[list.index(entry)]
811 except (KeyError, ValueError):
811 except (KeyError, ValueError):
812 return None
812 return None
813
813
814 def getByDetails(self, name, type, clazz):
814 def getByDetails(self, name, type, clazz):
815 """Gets an entry by details. Will return None if there is
815 """Gets an entry by details. Will return None if there is
816 no matching entry."""
816 no matching entry."""
817 entry = DNSEntry(name, type, clazz)
817 entry = DNSEntry(name, type, clazz)
818 return self.get(entry)
818 return self.get(entry)
819
819
820 def entriesWithName(self, name):
820 def entriesWithName(self, name):
821 """Returns a list of entries whose key matches the name."""
821 """Returns a list of entries whose key matches the name."""
822 try:
822 try:
823 return self.cache[name]
823 return self.cache[name]
824 except KeyError:
824 except KeyError:
825 return []
825 return []
826
826
827 def entries(self):
827 def entries(self):
828 """Returns a list of all entries"""
828 """Returns a list of all entries"""
829 def add(x, y): return x+y
829 def add(x, y): return x+y
830 try:
830 try:
831 return reduce(add, self.cache.values())
831 return reduce(add, self.cache.values())
832 except Exception:
832 except Exception:
833 return []
833 return []
834
834
835
835
836 class Engine(threading.Thread):
836 class Engine(threading.Thread):
837 """An engine wraps read access to sockets, allowing objects that
837 """An engine wraps read access to sockets, allowing objects that
838 need to receive data from sockets to be called back when the
838 need to receive data from sockets to be called back when the
839 sockets are ready.
839 sockets are ready.
840
840
841 A reader needs a handle_read() method, which is called when the socket
841 A reader needs a handle_read() method, which is called when the socket
842 it is interested in is ready for reading.
842 it is interested in is ready for reading.
843
843
844 Writers are not implemented here, because we only send short
844 Writers are not implemented here, because we only send short
845 packets.
845 packets.
846 """
846 """
847
847
848 def __init__(self, zeroconf):
848 def __init__(self, zeroconf):
849 threading.Thread.__init__(self)
849 threading.Thread.__init__(self)
850 self.zeroconf = zeroconf
850 self.zeroconf = zeroconf
851 self.readers = {} # maps socket to reader
851 self.readers = {} # maps socket to reader
852 self.timeout = 5
852 self.timeout = 5
853 self.condition = threading.Condition()
853 self.condition = threading.Condition()
854 self.start()
854 self.start()
855
855
856 def run(self):
856 def run(self):
857 while not globals()['_GLOBAL_DONE']:
857 while not globals()['_GLOBAL_DONE']:
858 rs = self.getReaders()
858 rs = self.getReaders()
859 if len(rs) == 0:
859 if len(rs) == 0:
860 # No sockets to manage, but we wait for the timeout
860 # No sockets to manage, but we wait for the timeout
861 # or addition of a socket
861 # or addition of a socket
862 #
862 #
863 self.condition.acquire()
863 self.condition.acquire()
864 self.condition.wait(self.timeout)
864 self.condition.wait(self.timeout)
865 self.condition.release()
865 self.condition.release()
866 else:
866 else:
867 try:
867 try:
868 rr, wr, er = select.select(rs, [], [], self.timeout)
868 rr, wr, er = select.select(rs, [], [], self.timeout)
869 for socket in rr:
869 for socket in rr:
870 try:
870 try:
871 self.readers[socket].handle_read()
871 self.readers[socket].handle_read()
872 except Exception:
872 except Exception:
873 if not globals()['_GLOBAL_DONE']:
873 if not globals()['_GLOBAL_DONE']:
874 traceback.print_exc()
874 traceback.print_exc()
875 except Exception:
875 except Exception:
876 pass
876 pass
877
877
878 def getReaders(self):
878 def getReaders(self):
879 self.condition.acquire()
879 self.condition.acquire()
880 result = self.readers.keys()
880 result = self.readers.keys()
881 self.condition.release()
881 self.condition.release()
882 return result
882 return result
883
883
884 def addReader(self, reader, socket):
884 def addReader(self, reader, socket):
885 self.condition.acquire()
885 self.condition.acquire()
886 self.readers[socket] = reader
886 self.readers[socket] = reader
887 self.condition.notify()
887 self.condition.notify()
888 self.condition.release()
888 self.condition.release()
889
889
890 def delReader(self, socket):
890 def delReader(self, socket):
891 self.condition.acquire()
891 self.condition.acquire()
892 del(self.readers[socket])
892 del(self.readers[socket])
893 self.condition.notify()
893 self.condition.notify()
894 self.condition.release()
894 self.condition.release()
895
895
896 def notify(self):
896 def notify(self):
897 self.condition.acquire()
897 self.condition.acquire()
898 self.condition.notify()
898 self.condition.notify()
899 self.condition.release()
899 self.condition.release()
900
900
901 class Listener(object):
901 class Listener(object):
902 """A Listener is used by this module to listen on the multicast
902 """A Listener is used by this module to listen on the multicast
903 group to which DNS messages are sent, allowing the implementation
903 group to which DNS messages are sent, allowing the implementation
904 to cache information as it arrives.
904 to cache information as it arrives.
905
905
906 It requires registration with an Engine object in order to have
906 It requires registration with an Engine object in order to have
907 the read() method called when a socket is available for reading."""
907 the read() method called when a socket is available for reading."""
908
908
909 def __init__(self, zeroconf):
909 def __init__(self, zeroconf):
910 self.zeroconf = zeroconf
910 self.zeroconf = zeroconf
911 self.zeroconf.engine.addReader(self, self.zeroconf.socket)
911 self.zeroconf.engine.addReader(self, self.zeroconf.socket)
912
912
913 def handle_read(self):
913 def handle_read(self):
914 data, (addr, port) = self.zeroconf.socket.recvfrom(_MAX_MSG_ABSOLUTE)
914 data, (addr, port) = self.zeroconf.socket.recvfrom(_MAX_MSG_ABSOLUTE)
915 self.data = data
915 self.data = data
916 msg = DNSIncoming(data)
916 msg = DNSIncoming(data)
917 if msg.isQuery():
917 if msg.isQuery():
918 # Always multicast responses
918 # Always multicast responses
919 #
919 #
920 if port == _MDNS_PORT:
920 if port == _MDNS_PORT:
921 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
921 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
922 # If it's not a multicast query, reply via unicast
922 # If it's not a multicast query, reply via unicast
923 # and multicast
923 # and multicast
924 #
924 #
925 elif port == _DNS_PORT:
925 elif port == _DNS_PORT:
926 self.zeroconf.handleQuery(msg, addr, port)
926 self.zeroconf.handleQuery(msg, addr, port)
927 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
927 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
928 else:
928 else:
929 self.zeroconf.handleResponse(msg)
929 self.zeroconf.handleResponse(msg)
930
930
931
931
932 class Reaper(threading.Thread):
932 class Reaper(threading.Thread):
933 """A Reaper is used by this module to remove cache entries that
933 """A Reaper is used by this module to remove cache entries that
934 have expired."""
934 have expired."""
935
935
936 def __init__(self, zeroconf):
936 def __init__(self, zeroconf):
937 threading.Thread.__init__(self)
937 threading.Thread.__init__(self)
938 self.zeroconf = zeroconf
938 self.zeroconf = zeroconf
939 self.start()
939 self.start()
940
940
941 def run(self):
941 def run(self):
942 while True:
942 while True:
943 self.zeroconf.wait(10 * 1000)
943 self.zeroconf.wait(10 * 1000)
944 if globals()['_GLOBAL_DONE']:
944 if globals()['_GLOBAL_DONE']:
945 return
945 return
946 now = currentTimeMillis()
946 now = currentTimeMillis()
947 for record in self.zeroconf.cache.entries():
947 for record in self.zeroconf.cache.entries():
948 if record.isExpired(now):
948 if record.isExpired(now):
949 self.zeroconf.updateRecord(now, record)
949 self.zeroconf.updateRecord(now, record)
950 self.zeroconf.cache.remove(record)
950 self.zeroconf.cache.remove(record)
951
951
952
952
953 class ServiceBrowser(threading.Thread):
953 class ServiceBrowser(threading.Thread):
954 """Used to browse for a service of a specific type.
954 """Used to browse for a service of a specific type.
955
955
956 The listener object will have its addService() and
956 The listener object will have its addService() and
957 removeService() methods called when this browser
957 removeService() methods called when this browser
958 discovers changes in the services availability."""
958 discovers changes in the services availability."""
959
959
960 def __init__(self, zeroconf, type, listener):
960 def __init__(self, zeroconf, type, listener):
961 """Creates a browser for a specific type"""
961 """Creates a browser for a specific type"""
962 threading.Thread.__init__(self)
962 threading.Thread.__init__(self)
963 self.zeroconf = zeroconf
963 self.zeroconf = zeroconf
964 self.type = type
964 self.type = type
965 self.listener = listener
965 self.listener = listener
966 self.services = {}
966 self.services = {}
967 self.nextTime = currentTimeMillis()
967 self.nextTime = currentTimeMillis()
968 self.delay = _BROWSER_TIME
968 self.delay = _BROWSER_TIME
969 self.list = []
969 self.list = []
970
970
971 self.done = 0
971 self.done = 0
972
972
973 self.zeroconf.addListener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
973 self.zeroconf.addListener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
974 self.start()
974 self.start()
975
975
976 def updateRecord(self, zeroconf, now, record):
976 def updateRecord(self, zeroconf, now, record):
977 """Callback invoked by Zeroconf when new information arrives.
977 """Callback invoked by Zeroconf when new information arrives.
978
978
979 Updates information required by browser in the Zeroconf cache."""
979 Updates information required by browser in the Zeroconf cache."""
980 if record.type == _TYPE_PTR and record.name == self.type:
980 if record.type == _TYPE_PTR and record.name == self.type:
981 expired = record.isExpired(now)
981 expired = record.isExpired(now)
982 try:
982 try:
983 oldrecord = self.services[record.alias.lower()]
983 oldrecord = self.services[record.alias.lower()]
984 if not expired:
984 if not expired:
985 oldrecord.resetTTL(record)
985 oldrecord.resetTTL(record)
986 else:
986 else:
987 del(self.services[record.alias.lower()])
987 del(self.services[record.alias.lower()])
988 callback = lambda x: self.listener.removeService(x, self.type, record.alias)
988 callback = lambda x: self.listener.removeService(x, self.type, record.alias)
989 self.list.append(callback)
989 self.list.append(callback)
990 return
990 return
991 except Exception:
991 except Exception:
992 if not expired:
992 if not expired:
993 self.services[record.alias.lower()] = record
993 self.services[record.alias.lower()] = record
994 callback = lambda x: self.listener.addService(x, self.type, record.alias)
994 callback = lambda x: self.listener.addService(x, self.type, record.alias)
995 self.list.append(callback)
995 self.list.append(callback)
996
996
997 expires = record.getExpirationTime(75)
997 expires = record.getExpirationTime(75)
998 if expires < self.nextTime:
998 if expires < self.nextTime:
999 self.nextTime = expires
999 self.nextTime = expires
1000
1000
1001 def cancel(self):
1001 def cancel(self):
1002 self.done = 1
1002 self.done = 1
1003 self.zeroconf.notifyAll()
1003 self.zeroconf.notifyAll()
1004
1004
1005 def run(self):
1005 def run(self):
1006 while True:
1006 while True:
1007 event = None
1007 event = None
1008 now = currentTimeMillis()
1008 now = currentTimeMillis()
1009 if len(self.list) == 0 and self.nextTime > now:
1009 if len(self.list) == 0 and self.nextTime > now:
1010 self.zeroconf.wait(self.nextTime - now)
1010 self.zeroconf.wait(self.nextTime - now)
1011 if globals()['_GLOBAL_DONE'] or self.done:
1011 if globals()['_GLOBAL_DONE'] or self.done:
1012 return
1012 return
1013 now = currentTimeMillis()
1013 now = currentTimeMillis()
1014
1014
1015 if self.nextTime <= now:
1015 if self.nextTime <= now:
1016 out = DNSOutgoing(_FLAGS_QR_QUERY)
1016 out = DNSOutgoing(_FLAGS_QR_QUERY)
1017 out.addQuestion(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1017 out.addQuestion(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1018 for record in self.services.values():
1018 for record in self.services.values():
1019 if not record.isExpired(now):
1019 if not record.isExpired(now):
1020 out.addAnswerAtTime(record, now)
1020 out.addAnswerAtTime(record, now)
1021 self.zeroconf.send(out)
1021 self.zeroconf.send(out)
1022 self.nextTime = now + self.delay
1022 self.nextTime = now + self.delay
1023 self.delay = min(20 * 1000, self.delay * 2)
1023 self.delay = min(20 * 1000, self.delay * 2)
1024
1024
1025 if len(self.list) > 0:
1025 if len(self.list) > 0:
1026 event = self.list.pop(0)
1026 event = self.list.pop(0)
1027
1027
1028 if event is not None:
1028 if event is not None:
1029 event(self.zeroconf)
1029 event(self.zeroconf)
1030
1030
1031
1031
1032 class ServiceInfo(object):
1032 class ServiceInfo(object):
1033 """Service information"""
1033 """Service information"""
1034
1034
1035 def __init__(self, type, name, address=None, port=None, weight=0, priority=0, properties=None, server=None):
1035 def __init__(self, type, name, address=None, port=None, weight=0, priority=0, properties=None, server=None):
1036 """Create a service description.
1036 """Create a service description.
1037
1037
1038 type: fully qualified service type name
1038 type: fully qualified service type name
1039 name: fully qualified service name
1039 name: fully qualified service name
1040 address: IP address as unsigned short, network byte order
1040 address: IP address as unsigned short, network byte order
1041 port: port that the service runs on
1041 port: port that the service runs on
1042 weight: weight of the service
1042 weight: weight of the service
1043 priority: priority of the service
1043 priority: priority of the service
1044 properties: dictionary of properties (or a string holding the bytes for the text field)
1044 properties: dictionary of properties (or a string holding the bytes for the text field)
1045 server: fully qualified name for service host (defaults to name)"""
1045 server: fully qualified name for service host (defaults to name)"""
1046
1046
1047 if not name.endswith(type):
1047 if not name.endswith(type):
1048 raise BadTypeInNameException
1048 raise BadTypeInNameException
1049 self.type = type
1049 self.type = type
1050 self.name = name
1050 self.name = name
1051 self.address = address
1051 self.address = address
1052 self.port = port
1052 self.port = port
1053 self.weight = weight
1053 self.weight = weight
1054 self.priority = priority
1054 self.priority = priority
1055 if server:
1055 if server:
1056 self.server = server
1056 self.server = server
1057 else:
1057 else:
1058 self.server = name
1058 self.server = name
1059 self.setProperties(properties)
1059 self.setProperties(properties)
1060
1060
1061 def setProperties(self, properties):
1061 def setProperties(self, properties):
1062 """Sets properties and text of this info from a dictionary"""
1062 """Sets properties and text of this info from a dictionary"""
1063 if isinstance(properties, dict):
1063 if isinstance(properties, dict):
1064 self.properties = properties
1064 self.properties = properties
1065 list = []
1065 list = []
1066 result = ''
1066 result = ''
1067 for key in properties:
1067 for key in properties:
1068 value = properties[key]
1068 value = properties[key]
1069 if value is None:
1069 if value is None:
1070 suffix = ''
1070 suffix = ''
1071 elif isinstance(value, str):
1071 elif isinstance(value, str):
1072 suffix = value
1072 suffix = value
1073 elif isinstance(value, int):
1073 elif isinstance(value, int):
1074 if value:
1074 if value:
1075 suffix = 'true'
1075 suffix = 'true'
1076 else:
1076 else:
1077 suffix = 'false'
1077 suffix = 'false'
1078 else:
1078 else:
1079 suffix = ''
1079 suffix = ''
1080 list.append('='.join((key, suffix)))
1080 list.append('='.join((key, suffix)))
1081 for item in list:
1081 for item in list:
1082 result = ''.join((result, struct.pack('!c', chr(len(item))), item))
1082 result = ''.join((result, struct.pack('!c', chr(len(item))), item))
1083 self.text = result
1083 self.text = result
1084 else:
1084 else:
1085 self.text = properties
1085 self.text = properties
1086
1086
1087 def setText(self, text):
1087 def setText(self, text):
1088 """Sets properties and text given a text field"""
1088 """Sets properties and text given a text field"""
1089 self.text = text
1089 self.text = text
1090 try:
1090 try:
1091 result = {}
1091 result = {}
1092 end = len(text)
1092 end = len(text)
1093 index = 0
1093 index = 0
1094 strs = []
1094 strs = []
1095 while index < end:
1095 while index < end:
1096 length = ord(text[index])
1096 length = ord(text[index])
1097 index += 1
1097 index += 1
1098 strs.append(text[index:index+length])
1098 strs.append(text[index:index+length])
1099 index += length
1099 index += length
1100
1100
1101 for s in strs:
1101 for s in strs:
1102 eindex = s.find('=')
1102 eindex = s.find('=')
1103 if eindex == -1:
1103 if eindex == -1:
1104 # No equals sign at all
1104 # No equals sign at all
1105 key = s
1105 key = s
1106 value = 0
1106 value = 0
1107 else:
1107 else:
1108 key = s[:eindex]
1108 key = s[:eindex]
1109 value = s[eindex+1:]
1109 value = s[eindex+1:]
1110 if value == 'true':
1110 if value == 'true':
1111 value = 1
1111 value = 1
1112 elif value == 'false' or not value:
1112 elif value == 'false' or not value:
1113 value = 0
1113 value = 0
1114
1114
1115 # Only update non-existent properties
1115 # Only update non-existent properties
1116 if key and result.get(key) == None:
1116 if key and result.get(key) == None:
1117 result[key] = value
1117 result[key] = value
1118
1118
1119 self.properties = result
1119 self.properties = result
1120 except Exception:
1120 except Exception:
1121 traceback.print_exc()
1121 traceback.print_exc()
1122 self.properties = None
1122 self.properties = None
1123
1123
1124 def getType(self):
1124 def getType(self):
1125 """Type accessor"""
1125 """Type accessor"""
1126 return self.type
1126 return self.type
1127
1127
1128 def getName(self):
1128 def getName(self):
1129 """Name accessor"""
1129 """Name accessor"""
1130 if self.type is not None and self.name.endswith("." + self.type):
1130 if self.type is not None and self.name.endswith("." + self.type):
1131 return self.name[:len(self.name) - len(self.type) - 1]
1131 return self.name[:len(self.name) - len(self.type) - 1]
1132 return self.name
1132 return self.name
1133
1133
1134 def getAddress(self):
1134 def getAddress(self):
1135 """Address accessor"""
1135 """Address accessor"""
1136 return self.address
1136 return self.address
1137
1137
1138 def getPort(self):
1138 def getPort(self):
1139 """Port accessor"""
1139 """Port accessor"""
1140 return self.port
1140 return self.port
1141
1141
1142 def getPriority(self):
1142 def getPriority(self):
1143 """Priority accessor"""
1143 """Priority accessor"""
1144 return self.priority
1144 return self.priority
1145
1145
1146 def getWeight(self):
1146 def getWeight(self):
1147 """Weight accessor"""
1147 """Weight accessor"""
1148 return self.weight
1148 return self.weight
1149
1149
1150 def getProperties(self):
1150 def getProperties(self):
1151 """Properties accessor"""
1151 """Properties accessor"""
1152 return self.properties
1152 return self.properties
1153
1153
1154 def getText(self):
1154 def getText(self):
1155 """Text accessor"""
1155 """Text accessor"""
1156 return self.text
1156 return self.text
1157
1157
1158 def getServer(self):
1158 def getServer(self):
1159 """Server accessor"""
1159 """Server accessor"""
1160 return self.server
1160 return self.server
1161
1161
1162 def updateRecord(self, zeroconf, now, record):
1162 def updateRecord(self, zeroconf, now, record):
1163 """Updates service information from a DNS record"""
1163 """Updates service information from a DNS record"""
1164 if record is not None and not record.isExpired(now):
1164 if record is not None and not record.isExpired(now):
1165 if record.type == _TYPE_A:
1165 if record.type == _TYPE_A:
1166 #if record.name == self.name:
1166 #if record.name == self.name:
1167 if record.name == self.server:
1167 if record.name == self.server:
1168 self.address = record.address
1168 self.address = record.address
1169 elif record.type == _TYPE_SRV:
1169 elif record.type == _TYPE_SRV:
1170 if record.name == self.name:
1170 if record.name == self.name:
1171 self.server = record.server
1171 self.server = record.server
1172 self.port = record.port
1172 self.port = record.port
1173 self.weight = record.weight
1173 self.weight = record.weight
1174 self.priority = record.priority
1174 self.priority = record.priority
1175 #self.address = None
1175 #self.address = None
1176 self.updateRecord(zeroconf, now, zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN))
1176 self.updateRecord(zeroconf, now, zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN))
1177 elif record.type == _TYPE_TXT:
1177 elif record.type == _TYPE_TXT:
1178 if record.name == self.name:
1178 if record.name == self.name:
1179 self.setText(record.text)
1179 self.setText(record.text)
1180
1180
1181 def request(self, zeroconf, timeout):
1181 def request(self, zeroconf, timeout):
1182 """Returns true if the service could be discovered on the
1182 """Returns true if the service could be discovered on the
1183 network, and updates this object with details discovered.
1183 network, and updates this object with details discovered.
1184 """
1184 """
1185 now = currentTimeMillis()
1185 now = currentTimeMillis()
1186 delay = _LISTENER_TIME
1186 delay = _LISTENER_TIME
1187 next = now + delay
1187 next = now + delay
1188 last = now + timeout
1188 last = now + timeout
1189 result = 0
1189 result = 0
1190 try:
1190 try:
1191 zeroconf.addListener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
1191 zeroconf.addListener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
1192 while self.server is None or self.address is None or self.text is None:
1192 while self.server is None or self.address is None or self.text is None:
1193 if last <= now:
1193 if last <= now:
1194 return 0
1194 return 0
1195 if next <= now:
1195 if next <= now:
1196 out = DNSOutgoing(_FLAGS_QR_QUERY)
1196 out = DNSOutgoing(_FLAGS_QR_QUERY)
1197 out.addQuestion(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
1197 out.addQuestion(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
1198 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_SRV, _CLASS_IN), now)
1198 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_SRV, _CLASS_IN), now)
1199 out.addQuestion(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
1199 out.addQuestion(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
1200 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_TXT, _CLASS_IN), now)
1200 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_TXT, _CLASS_IN), now)
1201 if self.server is not None:
1201 if self.server is not None:
1202 out.addQuestion(DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
1202 out.addQuestion(DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
1203 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN), now)
1203 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN), now)
1204 zeroconf.send(out)
1204 zeroconf.send(out)
1205 next = now + delay
1205 next = now + delay
1206 delay = delay * 2
1206 delay = delay * 2
1207
1207
1208 zeroconf.wait(min(next, last) - now)
1208 zeroconf.wait(min(next, last) - now)
1209 now = currentTimeMillis()
1209 now = currentTimeMillis()
1210 result = 1
1210 result = 1
1211 finally:
1211 finally:
1212 zeroconf.removeListener(self)
1212 zeroconf.removeListener(self)
1213
1213
1214 return result
1214 return result
1215
1215
1216 def __eq__(self, other):
1216 def __eq__(self, other):
1217 """Tests equality of service name"""
1217 """Tests equality of service name"""
1218 if isinstance(other, ServiceInfo):
1218 if isinstance(other, ServiceInfo):
1219 return other.name == self.name
1219 return other.name == self.name
1220 return 0
1220 return 0
1221
1221
1222 def __ne__(self, other):
1222 def __ne__(self, other):
1223 """Non-equality test"""
1223 """Non-equality test"""
1224 return not self.__eq__(other)
1224 return not self.__eq__(other)
1225
1225
1226 def __repr__(self):
1226 def __repr__(self):
1227 """String representation"""
1227 """String representation"""
1228 result = "service[%s,%s:%s," % (self.name, socket.inet_ntoa(self.getAddress()), self.port)
1228 result = "service[%s,%s:%s," % (self.name, socket.inet_ntoa(self.getAddress()), self.port)
1229 if self.text is None:
1229 if self.text is None:
1230 result += "None"
1230 result += "None"
1231 else:
1231 else:
1232 if len(self.text) < 20:
1232 if len(self.text) < 20:
1233 result += self.text
1233 result += self.text
1234 else:
1234 else:
1235 result += self.text[:17] + "..."
1235 result += self.text[:17] + "..."
1236 result += "]"
1236 result += "]"
1237 return result
1237 return result
1238
1238
1239
1239
1240 class Zeroconf(object):
1240 class Zeroconf(object):
1241 """Implementation of Zeroconf Multicast DNS Service Discovery
1241 """Implementation of Zeroconf Multicast DNS Service Discovery
1242
1242
1243 Supports registration, unregistration, queries and browsing.
1243 Supports registration, unregistration, queries and browsing.
1244 """
1244 """
1245 def __init__(self, bindaddress=None):
1245 def __init__(self, bindaddress=None):
1246 """Creates an instance of the Zeroconf class, establishing
1246 """Creates an instance of the Zeroconf class, establishing
1247 multicast communications, listening and reaping threads."""
1247 multicast communications, listening and reaping threads."""
1248 globals()['_GLOBAL_DONE'] = 0
1248 globals()['_GLOBAL_DONE'] = 0
1249 if bindaddress is None:
1249 if bindaddress is None:
1250 self.intf = socket.gethostbyname(socket.gethostname())
1250 self.intf = socket.gethostbyname(socket.gethostname())
1251 else:
1251 else:
1252 self.intf = bindaddress
1252 self.intf = bindaddress
1253 self.group = ('', _MDNS_PORT)
1253 self.group = ('', _MDNS_PORT)
1254 self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1254 self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1255 try:
1255 try:
1256 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1256 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1257 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
1257 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
1258 except Exception:
1258 except Exception:
1259 # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
1259 # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
1260 # multicast UDP sockets (p 731, "TCP/IP Illustrated,
1260 # multicast UDP sockets (p 731, "TCP/IP Illustrated,
1261 # Volume 2"), but some BSD-derived systems require
1261 # Volume 2"), but some BSD-derived systems require
1262 # SO_REUSEPORT to be specified explicitly. Also, not all
1262 # SO_REUSEPORT to be specified explicitly. Also, not all
1263 # versions of Python have SO_REUSEPORT available. So
1263 # versions of Python have SO_REUSEPORT available. So
1264 # if you're on a BSD-based system, and haven't upgraded
1264 # if you're on a BSD-based system, and haven't upgraded
1265 # to Python 2.3 yet, you may find this library doesn't
1265 # to Python 2.3 yet, you may find this library doesn't
1266 # work as expected.
1266 # work as expected.
1267 #
1267 #
1268 pass
1268 pass
1269 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 255)
1269 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 255)
1270 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
1270 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
1271 try:
1271 try:
1272 self.socket.bind(self.group)
1272 self.socket.bind(self.group)
1273 except Exception:
1273 except Exception:
1274 # Some versions of linux raise an exception even though
1274 # Some versions of linux raise an exception even though
1275 # the SO_REUSE* options have been set, so ignore it
1275 # the SO_REUSE* options have been set, so ignore it
1276 #
1276 #
1277 pass
1277 pass
1278 #self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(self.intf) + socket.inet_aton('0.0.0.0'))
1279 self.socket.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1278 self.socket.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1280
1279
1281 self.listeners = []
1280 self.listeners = []
1282 self.browsers = []
1281 self.browsers = []
1283 self.services = {}
1282 self.services = {}
1284 self.servicetypes = {}
1283 self.servicetypes = {}
1285
1284
1286 self.cache = DNSCache()
1285 self.cache = DNSCache()
1287
1286
1288 self.condition = threading.Condition()
1287 self.condition = threading.Condition()
1289
1288
1290 self.engine = Engine(self)
1289 self.engine = Engine(self)
1291 self.listener = Listener(self)
1290 self.listener = Listener(self)
1292 self.reaper = Reaper(self)
1291 self.reaper = Reaper(self)
1293
1292
1294 def isLoopback(self):
1293 def isLoopback(self):
1295 return self.intf.startswith("127.0.0.1")
1294 return self.intf.startswith("127.0.0.1")
1296
1295
1297 def isLinklocal(self):
1296 def isLinklocal(self):
1298 return self.intf.startswith("169.254.")
1297 return self.intf.startswith("169.254.")
1299
1298
1300 def wait(self, timeout):
1299 def wait(self, timeout):
1301 """Calling thread waits for a given number of milliseconds or
1300 """Calling thread waits for a given number of milliseconds or
1302 until notified."""
1301 until notified."""
1303 self.condition.acquire()
1302 self.condition.acquire()
1304 self.condition.wait(timeout/1000)
1303 self.condition.wait(timeout/1000)
1305 self.condition.release()
1304 self.condition.release()
1306
1305
1307 def notifyAll(self):
1306 def notifyAll(self):
1308 """Notifies all waiting threads"""
1307 """Notifies all waiting threads"""
1309 self.condition.acquire()
1308 self.condition.acquire()
1310 self.condition.notifyAll()
1309 self.condition.notifyAll()
1311 self.condition.release()
1310 self.condition.release()
1312
1311
1313 def getServiceInfo(self, type, name, timeout=3000):
1312 def getServiceInfo(self, type, name, timeout=3000):
1314 """Returns network's service information for a particular
1313 """Returns network's service information for a particular
1315 name and type, or None if no service matches by the timeout,
1314 name and type, or None if no service matches by the timeout,
1316 which defaults to 3 seconds."""
1315 which defaults to 3 seconds."""
1317 info = ServiceInfo(type, name)
1316 info = ServiceInfo(type, name)
1318 if info.request(self, timeout):
1317 if info.request(self, timeout):
1319 return info
1318 return info
1320 return None
1319 return None
1321
1320
1322 def addServiceListener(self, type, listener):
1321 def addServiceListener(self, type, listener):
1323 """Adds a listener for a particular service type. This object
1322 """Adds a listener for a particular service type. This object
1324 will then have its updateRecord method called when information
1323 will then have its updateRecord method called when information
1325 arrives for that type."""
1324 arrives for that type."""
1326 self.removeServiceListener(listener)
1325 self.removeServiceListener(listener)
1327 self.browsers.append(ServiceBrowser(self, type, listener))
1326 self.browsers.append(ServiceBrowser(self, type, listener))
1328
1327
1329 def removeServiceListener(self, listener):
1328 def removeServiceListener(self, listener):
1330 """Removes a listener from the set that is currently listening."""
1329 """Removes a listener from the set that is currently listening."""
1331 for browser in self.browsers:
1330 for browser in self.browsers:
1332 if browser.listener == listener:
1331 if browser.listener == listener:
1333 browser.cancel()
1332 browser.cancel()
1334 del(browser)
1333 del(browser)
1335
1334
1336 def registerService(self, info, ttl=_DNS_TTL):
1335 def registerService(self, info, ttl=_DNS_TTL):
1337 """Registers service information to the network with a default TTL
1336 """Registers service information to the network with a default TTL
1338 of 60 seconds. Zeroconf will then respond to requests for
1337 of 60 seconds. Zeroconf will then respond to requests for
1339 information for that service. The name of the service may be
1338 information for that service. The name of the service may be
1340 changed if needed to make it unique on the network."""
1339 changed if needed to make it unique on the network."""
1341 self.checkService(info)
1340 self.checkService(info)
1342 self.services[info.name.lower()] = info
1341 self.services[info.name.lower()] = info
1343 if self.servicetypes.has_key(info.type):
1342 if self.servicetypes.has_key(info.type):
1344 self.servicetypes[info.type]+=1
1343 self.servicetypes[info.type]+=1
1345 else:
1344 else:
1346 self.servicetypes[info.type]=1
1345 self.servicetypes[info.type]=1
1347 now = currentTimeMillis()
1346 now = currentTimeMillis()
1348 nextTime = now
1347 nextTime = now
1349 i = 0
1348 i = 0
1350 while i < 3:
1349 while i < 3:
1351 if now < nextTime:
1350 if now < nextTime:
1352 self.wait(nextTime - now)
1351 self.wait(nextTime - now)
1353 now = currentTimeMillis()
1352 now = currentTimeMillis()
1354 continue
1353 continue
1355 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1354 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1356 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
1355 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
1357 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, ttl, info.priority, info.weight, info.port, info.server), 0)
1356 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, ttl, info.priority, info.weight, info.port, info.server), 0)
1358 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
1357 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
1359 if info.address:
1358 if info.address:
1360 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, ttl, info.address), 0)
1359 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, ttl, info.address), 0)
1361 self.send(out)
1360 self.send(out)
1362 i += 1
1361 i += 1
1363 nextTime += _REGISTER_TIME
1362 nextTime += _REGISTER_TIME
1364
1363
1365 def unregisterService(self, info):
1364 def unregisterService(self, info):
1366 """Unregister a service."""
1365 """Unregister a service."""
1367 try:
1366 try:
1368 del(self.services[info.name.lower()])
1367 del(self.services[info.name.lower()])
1369 if self.servicetypes[info.type]>1:
1368 if self.servicetypes[info.type]>1:
1370 self.servicetypes[info.type]-=1
1369 self.servicetypes[info.type]-=1
1371 else:
1370 else:
1372 del self.servicetypes[info.type]
1371 del self.servicetypes[info.type]
1373 except KeyError:
1372 except KeyError:
1374 pass
1373 pass
1375 now = currentTimeMillis()
1374 now = currentTimeMillis()
1376 nextTime = now
1375 nextTime = now
1377 i = 0
1376 i = 0
1378 while i < 3:
1377 while i < 3:
1379 if now < nextTime:
1378 if now < nextTime:
1380 self.wait(nextTime - now)
1379 self.wait(nextTime - now)
1381 now = currentTimeMillis()
1380 now = currentTimeMillis()
1382 continue
1381 continue
1383 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1382 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1384 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1383 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1385 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.name), 0)
1384 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.name), 0)
1386 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1385 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1387 if info.address:
1386 if info.address:
1388 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1387 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1389 self.send(out)
1388 self.send(out)
1390 i += 1
1389 i += 1
1391 nextTime += _UNREGISTER_TIME
1390 nextTime += _UNREGISTER_TIME
1392
1391
1393 def unregisterAllServices(self):
1392 def unregisterAllServices(self):
1394 """Unregister all registered services."""
1393 """Unregister all registered services."""
1395 if len(self.services) > 0:
1394 if len(self.services) > 0:
1396 now = currentTimeMillis()
1395 now = currentTimeMillis()
1397 nextTime = now
1396 nextTime = now
1398 i = 0
1397 i = 0
1399 while i < 3:
1398 while i < 3:
1400 if now < nextTime:
1399 if now < nextTime:
1401 self.wait(nextTime - now)
1400 self.wait(nextTime - now)
1402 now = currentTimeMillis()
1401 now = currentTimeMillis()
1403 continue
1402 continue
1404 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1403 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1405 for info in self.services.values():
1404 for info in self.services.values():
1406 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1405 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1407 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.server), 0)
1406 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.server), 0)
1408 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1407 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1409 if info.address:
1408 if info.address:
1410 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1409 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1411 self.send(out)
1410 self.send(out)
1412 i += 1
1411 i += 1
1413 nextTime += _UNREGISTER_TIME
1412 nextTime += _UNREGISTER_TIME
1414
1413
1415 def checkService(self, info):
1414 def checkService(self, info):
1416 """Checks the network for a unique service name, modifying the
1415 """Checks the network for a unique service name, modifying the
1417 ServiceInfo passed in if it is not unique."""
1416 ServiceInfo passed in if it is not unique."""
1418 now = currentTimeMillis()
1417 now = currentTimeMillis()
1419 nextTime = now
1418 nextTime = now
1420 i = 0
1419 i = 0
1421 while i < 3:
1420 while i < 3:
1422 for record in self.cache.entriesWithName(info.type):
1421 for record in self.cache.entriesWithName(info.type):
1423 if record.type == _TYPE_PTR and not record.isExpired(now) and record.alias == info.name:
1422 if record.type == _TYPE_PTR and not record.isExpired(now) and record.alias == info.name:
1424 if (info.name.find('.') < 0):
1423 if (info.name.find('.') < 0):
1425 info.name = info.name + ".[" + info.address + ":" + info.port + "]." + info.type
1424 info.name = info.name + ".[" + info.address + ":" + info.port + "]." + info.type
1426 self.checkService(info)
1425 self.checkService(info)
1427 return
1426 return
1428 raise NonUniqueNameException
1427 raise NonUniqueNameException
1429 if now < nextTime:
1428 if now < nextTime:
1430 self.wait(nextTime - now)
1429 self.wait(nextTime - now)
1431 now = currentTimeMillis()
1430 now = currentTimeMillis()
1432 continue
1431 continue
1433 out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
1432 out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
1434 self.debug = out
1433 self.debug = out
1435 out.addQuestion(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
1434 out.addQuestion(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
1436 out.addAuthoritativeAnswer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, info.name))
1435 out.addAuthoritativeAnswer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, info.name))
1437 self.send(out)
1436 self.send(out)
1438 i += 1
1437 i += 1
1439 nextTime += _CHECK_TIME
1438 nextTime += _CHECK_TIME
1440
1439
1441 def addListener(self, listener, question):
1440 def addListener(self, listener, question):
1442 """Adds a listener for a given question. The listener will have
1441 """Adds a listener for a given question. The listener will have
1443 its updateRecord method called when information is available to
1442 its updateRecord method called when information is available to
1444 answer the question."""
1443 answer the question."""
1445 now = currentTimeMillis()
1444 now = currentTimeMillis()
1446 self.listeners.append(listener)
1445 self.listeners.append(listener)
1447 if question is not None:
1446 if question is not None:
1448 for record in self.cache.entriesWithName(question.name):
1447 for record in self.cache.entriesWithName(question.name):
1449 if question.answeredBy(record) and not record.isExpired(now):
1448 if question.answeredBy(record) and not record.isExpired(now):
1450 listener.updateRecord(self, now, record)
1449 listener.updateRecord(self, now, record)
1451 self.notifyAll()
1450 self.notifyAll()
1452
1451
1453 def removeListener(self, listener):
1452 def removeListener(self, listener):
1454 """Removes a listener."""
1453 """Removes a listener."""
1455 try:
1454 try:
1456 self.listeners.remove(listener)
1455 self.listeners.remove(listener)
1457 self.notifyAll()
1456 self.notifyAll()
1458 except Exception:
1457 except Exception:
1459 pass
1458 pass
1460
1459
1461 def updateRecord(self, now, rec):
1460 def updateRecord(self, now, rec):
1462 """Used to notify listeners of new information that has updated
1461 """Used to notify listeners of new information that has updated
1463 a record."""
1462 a record."""
1464 for listener in self.listeners:
1463 for listener in self.listeners:
1465 listener.updateRecord(self, now, rec)
1464 listener.updateRecord(self, now, rec)
1466 self.notifyAll()
1465 self.notifyAll()
1467
1466
1468 def handleResponse(self, msg):
1467 def handleResponse(self, msg):
1469 """Deal with incoming response packets. All answers
1468 """Deal with incoming response packets. All answers
1470 are held in the cache, and listeners are notified."""
1469 are held in the cache, and listeners are notified."""
1471 now = currentTimeMillis()
1470 now = currentTimeMillis()
1472 for record in msg.answers:
1471 for record in msg.answers:
1473 expired = record.isExpired(now)
1472 expired = record.isExpired(now)
1474 if record in self.cache.entries():
1473 if record in self.cache.entries():
1475 if expired:
1474 if expired:
1476 self.cache.remove(record)
1475 self.cache.remove(record)
1477 else:
1476 else:
1478 entry = self.cache.get(record)
1477 entry = self.cache.get(record)
1479 if entry is not None:
1478 if entry is not None:
1480 entry.resetTTL(record)
1479 entry.resetTTL(record)
1481 record = entry
1480 record = entry
1482 else:
1481 else:
1483 self.cache.add(record)
1482 self.cache.add(record)
1484
1483
1485 self.updateRecord(now, record)
1484 self.updateRecord(now, record)
1486
1485
1487 def handleQuery(self, msg, addr, port):
1486 def handleQuery(self, msg, addr, port):
1488 """Deal with incoming query packets. Provides a response if
1487 """Deal with incoming query packets. Provides a response if
1489 possible."""
1488 possible."""
1490 out = None
1489 out = None
1491
1490
1492 # Support unicast client responses
1491 # Support unicast client responses
1493 #
1492 #
1494 if port != _MDNS_PORT:
1493 if port != _MDNS_PORT:
1495 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, 0)
1494 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, 0)
1496 for question in msg.questions:
1495 for question in msg.questions:
1497 out.addQuestion(question)
1496 out.addQuestion(question)
1498
1497
1499 for question in msg.questions:
1498 for question in msg.questions:
1500 if question.type == _TYPE_PTR:
1499 if question.type == _TYPE_PTR:
1501 if question.name == "_services._dns-sd._udp.local.":
1500 if question.name == "_services._dns-sd._udp.local.":
1502 for stype in self.servicetypes.keys():
1501 for stype in self.servicetypes.keys():
1503 if out is None:
1502 if out is None:
1504 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1503 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1505 out.addAnswer(msg, DNSPointer("_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_TTL, stype))
1504 out.addAnswer(msg, DNSPointer("_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_TTL, stype))
1506 for service in self.services.values():
1505 for service in self.services.values():
1507 if question.name == service.type:
1506 if question.name == service.type:
1508 if out is None:
1507 if out is None:
1509 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1508 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1510 out.addAnswer(msg, DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, service.name))
1509 out.addAnswer(msg, DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, service.name))
1511 else:
1510 else:
1512 try:
1511 try:
1513 if out is None:
1512 if out is None:
1514 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1513 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1515
1514
1516 # Answer A record queries for any service addresses we know
1515 # Answer A record queries for any service addresses we know
1517 if question.type == _TYPE_A or question.type == _TYPE_ANY:
1516 if question.type == _TYPE_A or question.type == _TYPE_ANY:
1518 for service in self.services.values():
1517 for service in self.services.values():
1519 if service.server == question.name.lower():
1518 if service.server == question.name.lower():
1520 out.addAnswer(msg, DNSAddress(question.name, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1519 out.addAnswer(msg, DNSAddress(question.name, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1521
1520
1522 service = self.services.get(question.name.lower(), None)
1521 service = self.services.get(question.name.lower(), None)
1523 if not service: continue
1522 if not service: continue
1524
1523
1525 if question.type == _TYPE_SRV or question.type == _TYPE_ANY:
1524 if question.type == _TYPE_SRV or question.type == _TYPE_ANY:
1526 out.addAnswer(msg, DNSService(question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.priority, service.weight, service.port, service.server))
1525 out.addAnswer(msg, DNSService(question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.priority, service.weight, service.port, service.server))
1527 if question.type == _TYPE_TXT or question.type == _TYPE_ANY:
1526 if question.type == _TYPE_TXT or question.type == _TYPE_ANY:
1528 out.addAnswer(msg, DNSText(question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.text))
1527 out.addAnswer(msg, DNSText(question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.text))
1529 if question.type == _TYPE_SRV:
1528 if question.type == _TYPE_SRV:
1530 out.addAdditionalAnswer(DNSAddress(service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1529 out.addAdditionalAnswer(DNSAddress(service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1531 except Exception:
1530 except Exception:
1532 traceback.print_exc()
1531 traceback.print_exc()
1533
1532
1534 if out is not None and out.answers:
1533 if out is not None and out.answers:
1535 out.id = msg.id
1534 out.id = msg.id
1536 self.send(out, addr, port)
1535 self.send(out, addr, port)
1537
1536
1538 def send(self, out, addr = _MDNS_ADDR, port = _MDNS_PORT):
1537 def send(self, out, addr = _MDNS_ADDR, port = _MDNS_PORT):
1539 """Sends an outgoing packet."""
1538 """Sends an outgoing packet."""
1540 # This is a quick test to see if we can parse the packets we generate
1539 # This is a quick test to see if we can parse the packets we generate
1541 #temp = DNSIncoming(out.packet())
1540 #temp = DNSIncoming(out.packet())
1542 try:
1541 try:
1543 self.socket.sendto(out.packet(), 0, (addr, port))
1542 self.socket.sendto(out.packet(), 0, (addr, port))
1544 except Exception:
1543 except Exception:
1545 # Ignore this, it may be a temporary loss of network connection
1544 # Ignore this, it may be a temporary loss of network connection
1546 pass
1545 pass
1547
1546
1548 def close(self):
1547 def close(self):
1549 """Ends the background threads, and prevent this instance from
1548 """Ends the background threads, and prevent this instance from
1550 servicing further queries."""
1549 servicing further queries."""
1551 if globals()['_GLOBAL_DONE'] == 0:
1550 if globals()['_GLOBAL_DONE'] == 0:
1552 globals()['_GLOBAL_DONE'] = 1
1551 globals()['_GLOBAL_DONE'] = 1
1553 self.notifyAll()
1552 self.notifyAll()
1554 self.engine.notify()
1553 self.engine.notify()
1555 self.unregisterAllServices()
1554 self.unregisterAllServices()
1556 self.socket.setsockopt(socket.SOL_IP, socket.IP_DROP_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1555 self.socket.setsockopt(socket.SOL_IP, socket.IP_DROP_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1557 self.socket.close()
1556 self.socket.close()
1558
1557
1559 # Test a few module features, including service registration, service
1558 # Test a few module features, including service registration, service
1560 # query (for Zoe), and service unregistration.
1559 # query (for Zoe), and service unregistration.
1561
1560
1562 if __name__ == '__main__':
1561 if __name__ == '__main__':
1563 print "Multicast DNS Service Discovery for Python, version", __version__
1562 print "Multicast DNS Service Discovery for Python, version", __version__
1564 r = Zeroconf()
1563 r = Zeroconf()
1565 print "1. Testing registration of a service..."
1564 print "1. Testing registration of a service..."
1566 desc = {'version':'0.10','a':'test value', 'b':'another value'}
1565 desc = {'version':'0.10','a':'test value', 'b':'another value'}
1567 info = ServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local.", socket.inet_aton("127.0.0.1"), 1234, 0, 0, desc)
1566 info = ServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local.", socket.inet_aton("127.0.0.1"), 1234, 0, 0, desc)
1568 print " Registering service..."
1567 print " Registering service..."
1569 r.registerService(info)
1568 r.registerService(info)
1570 print " Registration done."
1569 print " Registration done."
1571 print "2. Testing query of service information..."
1570 print "2. Testing query of service information..."
1572 print " Getting ZOE service:", str(r.getServiceInfo("_http._tcp.local.", "ZOE._http._tcp.local."))
1571 print " Getting ZOE service:", str(r.getServiceInfo("_http._tcp.local.", "ZOE._http._tcp.local."))
1573 print " Query done."
1572 print " Query done."
1574 print "3. Testing query of own service..."
1573 print "3. Testing query of own service..."
1575 print " Getting self:", str(r.getServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local."))
1574 print " Getting self:", str(r.getServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local."))
1576 print " Query done."
1575 print " Query done."
1577 print "4. Testing unregister of service information..."
1576 print "4. Testing unregister of service information..."
1578 r.unregisterService(info)
1577 r.unregisterService(info)
1579 print " Unregister done."
1578 print " Unregister done."
1580 r.close()
1579 r.close()
1581
1580
1582 # no-check-code
1581 # no-check-code
@@ -1,206 +1,204
1 # manifest.py - manifest revision class for mercurial
1 # manifest.py - manifest revision class for mercurial
2 #
2 #
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from i18n import _
8 from i18n import _
9 import mdiff, parsers, error, revlog, util
9 import mdiff, parsers, error, revlog, util
10 import array, struct
10 import array, struct
11
11
12 class manifestdict(dict):
12 class manifestdict(dict):
13 def __init__(self, mapping=None, flags=None):
13 def __init__(self, mapping=None, flags=None):
14 if mapping is None:
14 if mapping is None:
15 mapping = {}
15 mapping = {}
16 if flags is None:
16 if flags is None:
17 flags = {}
17 flags = {}
18 dict.__init__(self, mapping)
18 dict.__init__(self, mapping)
19 self._flags = flags
19 self._flags = flags
20 def flags(self, f):
20 def flags(self, f):
21 return self._flags.get(f, "")
21 return self._flags.get(f, "")
22 def withflags(self):
22 def withflags(self):
23 return set(self._flags.keys())
23 return set(self._flags.keys())
24 def set(self, f, flags):
24 def set(self, f, flags):
25 self._flags[f] = flags
25 self._flags[f] = flags
26 def copy(self):
26 def copy(self):
27 return manifestdict(self, dict.copy(self._flags))
27 return manifestdict(self, dict.copy(self._flags))
28
28
29 class manifest(revlog.revlog):
29 class manifest(revlog.revlog):
30 def __init__(self, opener):
30 def __init__(self, opener):
31 self._mancache = None
31 self._mancache = None
32 revlog.revlog.__init__(self, opener, "00manifest.i")
32 revlog.revlog.__init__(self, opener, "00manifest.i")
33
33
34 def parse(self, lines):
34 def parse(self, lines):
35 mfdict = manifestdict()
35 mfdict = manifestdict()
36 parsers.parse_manifest(mfdict, mfdict._flags, lines)
36 parsers.parse_manifest(mfdict, mfdict._flags, lines)
37 return mfdict
37 return mfdict
38
38
39 def readdelta(self, node):
39 def readdelta(self, node):
40 r = self.rev(node)
40 r = self.rev(node)
41 return self.parse(mdiff.patchtext(self.revdiff(self.deltaparent(r), r)))
41 return self.parse(mdiff.patchtext(self.revdiff(self.deltaparent(r), r)))
42
42
43 def readfast(self, node):
43 def readfast(self, node):
44 '''use the faster of readdelta or read'''
44 '''use the faster of readdelta or read'''
45 r = self.rev(node)
45 r = self.rev(node)
46 deltaparent = self.deltaparent(r)
46 deltaparent = self.deltaparent(r)
47 if deltaparent != revlog.nullrev and deltaparent in self.parentrevs(r):
47 if deltaparent != revlog.nullrev and deltaparent in self.parentrevs(r):
48 return self.readdelta(node)
48 return self.readdelta(node)
49 return self.read(node)
49 return self.read(node)
50
50
51 def read(self, node):
51 def read(self, node):
52 if node == revlog.nullid:
52 if node == revlog.nullid:
53 return manifestdict() # don't upset local cache
53 return manifestdict() # don't upset local cache
54 if self._mancache and self._mancache[0] == node:
54 if self._mancache and self._mancache[0] == node:
55 return self._mancache[1]
55 return self._mancache[1]
56 text = self.revision(node)
56 text = self.revision(node)
57 arraytext = array.array('c', text)
57 arraytext = array.array('c', text)
58 mapping = self.parse(text)
58 mapping = self.parse(text)
59 self._mancache = (node, mapping, arraytext)
59 self._mancache = (node, mapping, arraytext)
60 return mapping
60 return mapping
61
61
62 def _search(self, m, s, lo=0, hi=None):
62 def _search(self, m, s, lo=0, hi=None):
63 '''return a tuple (start, end) that says where to find s within m.
63 '''return a tuple (start, end) that says where to find s within m.
64
64
65 If the string is found m[start:end] are the line containing
65 If the string is found m[start:end] are the line containing
66 that string. If start == end the string was not found and
66 that string. If start == end the string was not found and
67 they indicate the proper sorted insertion point. This was
67 they indicate the proper sorted insertion point.
68 taken from bisect_left, and modified to find line start/end as
69 it goes along.
70
68
71 m should be a buffer or a string
69 m should be a buffer or a string
72 s is a string'''
70 s is a string'''
73 def advance(i, c):
71 def advance(i, c):
74 while i < lenm and m[i] != c:
72 while i < lenm and m[i] != c:
75 i += 1
73 i += 1
76 return i
74 return i
77 if not s:
75 if not s:
78 return (lo, lo)
76 return (lo, lo)
79 lenm = len(m)
77 lenm = len(m)
80 if not hi:
78 if not hi:
81 hi = lenm
79 hi = lenm
82 while lo < hi:
80 while lo < hi:
83 mid = (lo + hi) // 2
81 mid = (lo + hi) // 2
84 start = mid
82 start = mid
85 while start > 0 and m[start - 1] != '\n':
83 while start > 0 and m[start - 1] != '\n':
86 start -= 1
84 start -= 1
87 end = advance(start, '\0')
85 end = advance(start, '\0')
88 if m[start:end] < s:
86 if m[start:end] < s:
89 # we know that after the null there are 40 bytes of sha1
87 # we know that after the null there are 40 bytes of sha1
90 # this translates to the bisect lo = mid + 1
88 # this translates to the bisect lo = mid + 1
91 lo = advance(end + 40, '\n') + 1
89 lo = advance(end + 40, '\n') + 1
92 else:
90 else:
93 # this translates to the bisect hi = mid
91 # this translates to the bisect hi = mid
94 hi = start
92 hi = start
95 end = advance(lo, '\0')
93 end = advance(lo, '\0')
96 found = m[lo:end]
94 found = m[lo:end]
97 if s == found:
95 if s == found:
98 # we know that after the null there are 40 bytes of sha1
96 # we know that after the null there are 40 bytes of sha1
99 end = advance(end + 40, '\n')
97 end = advance(end + 40, '\n')
100 return (lo, end + 1)
98 return (lo, end + 1)
101 else:
99 else:
102 return (lo, lo)
100 return (lo, lo)
103
101
104 def find(self, node, f):
102 def find(self, node, f):
105 '''look up entry for a single file efficiently.
103 '''look up entry for a single file efficiently.
106 return (node, flags) pair if found, (None, None) if not.'''
104 return (node, flags) pair if found, (None, None) if not.'''
107 if self._mancache and self._mancache[0] == node:
105 if self._mancache and self._mancache[0] == node:
108 return self._mancache[1].get(f), self._mancache[1].flags(f)
106 return self._mancache[1].get(f), self._mancache[1].flags(f)
109 text = self.revision(node)
107 text = self.revision(node)
110 start, end = self._search(text, f)
108 start, end = self._search(text, f)
111 if start == end:
109 if start == end:
112 return None, None
110 return None, None
113 l = text[start:end]
111 l = text[start:end]
114 f, n = l.split('\0')
112 f, n = l.split('\0')
115 return revlog.bin(n[:40]), n[40:-1]
113 return revlog.bin(n[:40]), n[40:-1]
116
114
117 def add(self, map, transaction, link, p1=None, p2=None,
115 def add(self, map, transaction, link, p1=None, p2=None,
118 changed=None):
116 changed=None):
119 # apply the changes collected during the bisect loop to our addlist
117 # apply the changes collected during the bisect loop to our addlist
120 # return a delta suitable for addrevision
118 # return a delta suitable for addrevision
121 def addlistdelta(addlist, x):
119 def addlistdelta(addlist, x):
122 # start from the bottom up
120 # start from the bottom up
123 # so changes to the offsets don't mess things up.
121 # so changes to the offsets don't mess things up.
124 for start, end, content in reversed(x):
122 for start, end, content in reversed(x):
125 if content:
123 if content:
126 addlist[start:end] = array.array('c', content)
124 addlist[start:end] = array.array('c', content)
127 else:
125 else:
128 del addlist[start:end]
126 del addlist[start:end]
129 return "".join(struct.pack(">lll", start, end, len(content))
127 return "".join(struct.pack(">lll", start, end, len(content))
130 + content for start, end, content in x)
128 + content for start, end, content in x)
131
129
132 def checkforbidden(l):
130 def checkforbidden(l):
133 for f in l:
131 for f in l:
134 if '\n' in f or '\r' in f:
132 if '\n' in f or '\r' in f:
135 raise error.RevlogError(
133 raise error.RevlogError(
136 _("'\\n' and '\\r' disallowed in filenames: %r") % f)
134 _("'\\n' and '\\r' disallowed in filenames: %r") % f)
137
135
138 # if we're using the cache, make sure it is valid and
136 # if we're using the cache, make sure it is valid and
139 # parented by the same node we're diffing against
137 # parented by the same node we're diffing against
140 if not (changed and self._mancache and p1 and self._mancache[0] == p1):
138 if not (changed and self._mancache and p1 and self._mancache[0] == p1):
141 files = sorted(map)
139 files = sorted(map)
142 checkforbidden(files)
140 checkforbidden(files)
143
141
144 # if this is changed to support newlines in filenames,
142 # if this is changed to support newlines in filenames,
145 # be sure to check the templates/ dir again (especially *-raw.tmpl)
143 # be sure to check the templates/ dir again (especially *-raw.tmpl)
146 hex, flags = revlog.hex, map.flags
144 hex, flags = revlog.hex, map.flags
147 text = ''.join("%s\0%s%s\n" % (f, hex(map[f]), flags(f))
145 text = ''.join("%s\0%s%s\n" % (f, hex(map[f]), flags(f))
148 for f in files)
146 for f in files)
149 arraytext = array.array('c', text)
147 arraytext = array.array('c', text)
150 cachedelta = None
148 cachedelta = None
151 else:
149 else:
152 added, removed = changed
150 added, removed = changed
153 addlist = self._mancache[2]
151 addlist = self._mancache[2]
154
152
155 checkforbidden(added)
153 checkforbidden(added)
156 # combine the changed lists into one list for sorting
154 # combine the changed lists into one list for sorting
157 work = [(x, False) for x in added]
155 work = [(x, False) for x in added]
158 work.extend((x, True) for x in removed)
156 work.extend((x, True) for x in removed)
159 # this could use heapq.merge() (from python2.6+) or equivalent
157 # this could use heapq.merge() (from python2.6+) or equivalent
160 # since the lists are already sorted
158 # since the lists are already sorted
161 work.sort()
159 work.sort()
162
160
163 delta = []
161 delta = []
164 dstart = None
162 dstart = None
165 dend = None
163 dend = None
166 dline = [""]
164 dline = [""]
167 start = 0
165 start = 0
168 # zero copy representation of addlist as a buffer
166 # zero copy representation of addlist as a buffer
169 addbuf = util.buffer(addlist)
167 addbuf = util.buffer(addlist)
170
168
171 # start with a readonly loop that finds the offset of
169 # start with a readonly loop that finds the offset of
172 # each line and creates the deltas
170 # each line and creates the deltas
173 for f, todelete in work:
171 for f, todelete in work:
174 # bs will either be the index of the item or the insert point
172 # bs will either be the index of the item or the insert point
175 start, end = self._search(addbuf, f, start)
173 start, end = self._search(addbuf, f, start)
176 if not todelete:
174 if not todelete:
177 l = "%s\0%s%s\n" % (f, revlog.hex(map[f]), map.flags(f))
175 l = "%s\0%s%s\n" % (f, revlog.hex(map[f]), map.flags(f))
178 else:
176 else:
179 if start == end:
177 if start == end:
180 # item we want to delete was not found, error out
178 # item we want to delete was not found, error out
181 raise AssertionError(
179 raise AssertionError(
182 _("failed to remove %s from manifest") % f)
180 _("failed to remove %s from manifest") % f)
183 l = ""
181 l = ""
184 if dstart is not None and dstart <= start and dend >= start:
182 if dstart is not None and dstart <= start and dend >= start:
185 if dend < end:
183 if dend < end:
186 dend = end
184 dend = end
187 if l:
185 if l:
188 dline.append(l)
186 dline.append(l)
189 else:
187 else:
190 if dstart is not None:
188 if dstart is not None:
191 delta.append([dstart, dend, "".join(dline)])
189 delta.append([dstart, dend, "".join(dline)])
192 dstart = start
190 dstart = start
193 dend = end
191 dend = end
194 dline = [l]
192 dline = [l]
195
193
196 if dstart is not None:
194 if dstart is not None:
197 delta.append([dstart, dend, "".join(dline)])
195 delta.append([dstart, dend, "".join(dline)])
198 # apply the delta to the addlist, and get a delta for addrevision
196 # apply the delta to the addlist, and get a delta for addrevision
199 cachedelta = (self.rev(p1), addlistdelta(addlist, delta))
197 cachedelta = (self.rev(p1), addlistdelta(addlist, delta))
200 arraytext = addlist
198 arraytext = addlist
201 text = util.buffer(arraytext)
199 text = util.buffer(arraytext)
202
200
203 n = self.addrevision(text, transaction, link, p1, p2, cachedelta)
201 n = self.addrevision(text, transaction, link, p1, p2, cachedelta)
204 self._mancache = (n, map, arraytext)
202 self._mancache = (n, map, arraytext)
205
203
206 return n
204 return n
@@ -1,331 +1,327
1 # obsolete.py - obsolete markers handling
1 # obsolete.py - obsolete markers handling
2 #
2 #
3 # Copyright 2012 Pierre-Yves David <pierre-yves.david@ens-lyon.org>
3 # Copyright 2012 Pierre-Yves David <pierre-yves.david@ens-lyon.org>
4 # Logilab SA <contact@logilab.fr>
4 # Logilab SA <contact@logilab.fr>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 """Obsolete markers handling
9 """Obsolete markers handling
10
10
11 An obsolete marker maps an old changeset to a list of new
11 An obsolete marker maps an old changeset to a list of new
12 changesets. If the list of new changesets is empty, the old changeset
12 changesets. If the list of new changesets is empty, the old changeset
13 is said to be "killed". Otherwise, the old changeset is being
13 is said to be "killed". Otherwise, the old changeset is being
14 "replaced" by the new changesets.
14 "replaced" by the new changesets.
15
15
16 Obsolete markers can be used to record and distribute changeset graph
16 Obsolete markers can be used to record and distribute changeset graph
17 transformations performed by history rewriting operations, and help
17 transformations performed by history rewriting operations, and help
18 building new tools to reconciliate conflicting rewriting actions. To
18 building new tools to reconciliate conflicting rewriting actions. To
19 facilitate conflicts resolution, markers include various annotations
19 facilitate conflicts resolution, markers include various annotations
20 besides old and news changeset identifiers, such as creation date or
20 besides old and news changeset identifiers, such as creation date or
21 author name.
21 author name.
22
22
23
23
24 Format
24 Format
25 ------
25 ------
26
26
27 Markers are stored in an append-only file stored in
27 Markers are stored in an append-only file stored in
28 '.hg/store/obsstore'.
28 '.hg/store/obsstore'.
29
29
30 The file starts with a version header:
30 The file starts with a version header:
31
31
32 - 1 unsigned byte: version number, starting at zero.
32 - 1 unsigned byte: version number, starting at zero.
33
33
34
34
35 The header is followed by the markers. Each marker is made of:
35 The header is followed by the markers. Each marker is made of:
36
36
37 - 1 unsigned byte: number of new changesets "R", could be zero.
37 - 1 unsigned byte: number of new changesets "R", could be zero.
38
38
39 - 1 unsigned 32-bits integer: metadata size "M" in bytes.
39 - 1 unsigned 32-bits integer: metadata size "M" in bytes.
40
40
41 - 1 byte: a bit field. It is reserved for flags used in obsolete
41 - 1 byte: a bit field. It is reserved for flags used in obsolete
42 markers common operations, to avoid repeated decoding of metadata
42 markers common operations, to avoid repeated decoding of metadata
43 entries.
43 entries.
44
44
45 - 20 bytes: obsoleted changeset identifier.
45 - 20 bytes: obsoleted changeset identifier.
46
46
47 - N*20 bytes: new changesets identifiers.
47 - N*20 bytes: new changesets identifiers.
48
48
49 - M bytes: metadata as a sequence of nul-terminated strings. Each
49 - M bytes: metadata as a sequence of nul-terminated strings. Each
50 string contains a key and a value, separated by a color ':', without
50 string contains a key and a value, separated by a color ':', without
51 additional encoding. Keys cannot contain '\0' or ':' and values
51 additional encoding. Keys cannot contain '\0' or ':' and values
52 cannot contain '\0'.
52 cannot contain '\0'.
53 """
53 """
54 import struct
54 import struct
55 import util, base85
55 import util, base85
56 from i18n import _
56 from i18n import _
57
57
58 # the obsolete feature is not mature enought to be enabled by default.
59 # you have to rely on third party extension extension to enable this.
60 _enabled = False
61
62 _pack = struct.pack
58 _pack = struct.pack
63 _unpack = struct.unpack
59 _unpack = struct.unpack
64
60
65 # the obsolete feature is not mature enough to be enabled by default.
61 # the obsolete feature is not mature enough to be enabled by default.
66 # you have to rely on third party extension extension to enable this.
62 # you have to rely on third party extension extension to enable this.
67 _enabled = False
63 _enabled = False
68
64
69 # data used for parsing and writing
65 # data used for parsing and writing
70 _fmversion = 0
66 _fmversion = 0
71 _fmfixed = '>BIB20s'
67 _fmfixed = '>BIB20s'
72 _fmnode = '20s'
68 _fmnode = '20s'
73 _fmfsize = struct.calcsize(_fmfixed)
69 _fmfsize = struct.calcsize(_fmfixed)
74 _fnodesize = struct.calcsize(_fmnode)
70 _fnodesize = struct.calcsize(_fmnode)
75
71
76 def _readmarkers(data):
72 def _readmarkers(data):
77 """Read and enumerate markers from raw data"""
73 """Read and enumerate markers from raw data"""
78 off = 0
74 off = 0
79 diskversion = _unpack('>B', data[off:off + 1])[0]
75 diskversion = _unpack('>B', data[off:off + 1])[0]
80 off += 1
76 off += 1
81 if diskversion != _fmversion:
77 if diskversion != _fmversion:
82 raise util.Abort(_('parsing obsolete marker: unknown version %r')
78 raise util.Abort(_('parsing obsolete marker: unknown version %r')
83 % diskversion)
79 % diskversion)
84
80
85 # Loop on markers
81 # Loop on markers
86 l = len(data)
82 l = len(data)
87 while off + _fmfsize <= l:
83 while off + _fmfsize <= l:
88 # read fixed part
84 # read fixed part
89 cur = data[off:off + _fmfsize]
85 cur = data[off:off + _fmfsize]
90 off += _fmfsize
86 off += _fmfsize
91 nbsuc, mdsize, flags, pre = _unpack(_fmfixed, cur)
87 nbsuc, mdsize, flags, pre = _unpack(_fmfixed, cur)
92 # read replacement
88 # read replacement
93 sucs = ()
89 sucs = ()
94 if nbsuc:
90 if nbsuc:
95 s = (_fnodesize * nbsuc)
91 s = (_fnodesize * nbsuc)
96 cur = data[off:off + s]
92 cur = data[off:off + s]
97 sucs = _unpack(_fmnode * nbsuc, cur)
93 sucs = _unpack(_fmnode * nbsuc, cur)
98 off += s
94 off += s
99 # read metadata
95 # read metadata
100 # (metadata will be decoded on demand)
96 # (metadata will be decoded on demand)
101 metadata = data[off:off + mdsize]
97 metadata = data[off:off + mdsize]
102 if len(metadata) != mdsize:
98 if len(metadata) != mdsize:
103 raise util.Abort(_('parsing obsolete marker: metadata is too '
99 raise util.Abort(_('parsing obsolete marker: metadata is too '
104 'short, %d bytes expected, got %d')
100 'short, %d bytes expected, got %d')
105 % (mdsize, len(metadata)))
101 % (mdsize, len(metadata)))
106 off += mdsize
102 off += mdsize
107 yield (pre, sucs, flags, metadata)
103 yield (pre, sucs, flags, metadata)
108
104
109 def encodemeta(meta):
105 def encodemeta(meta):
110 """Return encoded metadata string to string mapping.
106 """Return encoded metadata string to string mapping.
111
107
112 Assume no ':' in key and no '\0' in both key and value."""
108 Assume no ':' in key and no '\0' in both key and value."""
113 for key, value in meta.iteritems():
109 for key, value in meta.iteritems():
114 if ':' in key or '\0' in key:
110 if ':' in key or '\0' in key:
115 raise ValueError("':' and '\0' are forbidden in metadata key'")
111 raise ValueError("':' and '\0' are forbidden in metadata key'")
116 if '\0' in value:
112 if '\0' in value:
117 raise ValueError("':' are forbidden in metadata value'")
113 raise ValueError("':' are forbidden in metadata value'")
118 return '\0'.join(['%s:%s' % (k, meta[k]) for k in sorted(meta)])
114 return '\0'.join(['%s:%s' % (k, meta[k]) for k in sorted(meta)])
119
115
120 def decodemeta(data):
116 def decodemeta(data):
121 """Return string to string dictionary from encoded version."""
117 """Return string to string dictionary from encoded version."""
122 d = {}
118 d = {}
123 for l in data.split('\0'):
119 for l in data.split('\0'):
124 if l:
120 if l:
125 key, value = l.split(':')
121 key, value = l.split(':')
126 d[key] = value
122 d[key] = value
127 return d
123 return d
128
124
129 class marker(object):
125 class marker(object):
130 """Wrap obsolete marker raw data"""
126 """Wrap obsolete marker raw data"""
131
127
132 def __init__(self, repo, data):
128 def __init__(self, repo, data):
133 # the repo argument will be used to create changectx in later version
129 # the repo argument will be used to create changectx in later version
134 self._repo = repo
130 self._repo = repo
135 self._data = data
131 self._data = data
136 self._decodedmeta = None
132 self._decodedmeta = None
137
133
138 def precnode(self):
134 def precnode(self):
139 """Precursor changeset node identifier"""
135 """Precursor changeset node identifier"""
140 return self._data[0]
136 return self._data[0]
141
137
142 def succnodes(self):
138 def succnodes(self):
143 """List of successor changesets node identifiers"""
139 """List of successor changesets node identifiers"""
144 return self._data[1]
140 return self._data[1]
145
141
146 def metadata(self):
142 def metadata(self):
147 """Decoded metadata dictionary"""
143 """Decoded metadata dictionary"""
148 if self._decodedmeta is None:
144 if self._decodedmeta is None:
149 self._decodedmeta = decodemeta(self._data[3])
145 self._decodedmeta = decodemeta(self._data[3])
150 return self._decodedmeta
146 return self._decodedmeta
151
147
152 def date(self):
148 def date(self):
153 """Creation date as (unixtime, offset)"""
149 """Creation date as (unixtime, offset)"""
154 parts = self.metadata()['date'].split(' ')
150 parts = self.metadata()['date'].split(' ')
155 return (float(parts[0]), int(parts[1]))
151 return (float(parts[0]), int(parts[1]))
156
152
157 class obsstore(object):
153 class obsstore(object):
158 """Store obsolete markers
154 """Store obsolete markers
159
155
160 Markers can be accessed with two mappings:
156 Markers can be accessed with two mappings:
161 - precursors: old -> set(new)
157 - precursors: old -> set(new)
162 - successors: new -> set(old)
158 - successors: new -> set(old)
163 """
159 """
164
160
165 def __init__(self, sopener):
161 def __init__(self, sopener):
166 self._all = []
162 self._all = []
167 # new markers to serialize
163 # new markers to serialize
168 self.precursors = {}
164 self.precursors = {}
169 self.successors = {}
165 self.successors = {}
170 self.sopener = sopener
166 self.sopener = sopener
171 data = sopener.tryread('obsstore')
167 data = sopener.tryread('obsstore')
172 if data:
168 if data:
173 self._load(_readmarkers(data))
169 self._load(_readmarkers(data))
174
170
175 def __iter__(self):
171 def __iter__(self):
176 return iter(self._all)
172 return iter(self._all)
177
173
178 def __nonzero__(self):
174 def __nonzero__(self):
179 return bool(self._all)
175 return bool(self._all)
180
176
181 def create(self, transaction, prec, succs=(), flag=0, metadata=None):
177 def create(self, transaction, prec, succs=(), flag=0, metadata=None):
182 """obsolete: add a new obsolete marker
178 """obsolete: add a new obsolete marker
183
179
184 * ensuring it is hashable
180 * ensuring it is hashable
185 * check mandatory metadata
181 * check mandatory metadata
186 * encode metadata
182 * encode metadata
187 """
183 """
188 if metadata is None:
184 if metadata is None:
189 metadata = {}
185 metadata = {}
190 if len(prec) != 20:
186 if len(prec) != 20:
191 raise ValueError(prec)
187 raise ValueError(prec)
192 for succ in succs:
188 for succ in succs:
193 if len(succ) != 20:
189 if len(succ) != 20:
194 raise ValueError(succ)
190 raise ValueError(succ)
195 marker = (str(prec), tuple(succs), int(flag), encodemeta(metadata))
191 marker = (str(prec), tuple(succs), int(flag), encodemeta(metadata))
196 self.add(transaction, [marker])
192 self.add(transaction, [marker])
197
193
198 def add(self, transaction, markers):
194 def add(self, transaction, markers):
199 """Add new markers to the store
195 """Add new markers to the store
200
196
201 Take care of filtering duplicate.
197 Take care of filtering duplicate.
202 Return the number of new marker."""
198 Return the number of new marker."""
203 if not _enabled:
199 if not _enabled:
204 raise util.Abort('obsolete feature is not enabled on this repo')
200 raise util.Abort('obsolete feature is not enabled on this repo')
205 new = [m for m in markers if m not in self._all]
201 new = [m for m in markers if m not in self._all]
206 if new:
202 if new:
207 f = self.sopener('obsstore', 'ab')
203 f = self.sopener('obsstore', 'ab')
208 try:
204 try:
209 # Whether the file's current position is at the begin or at
205 # Whether the file's current position is at the begin or at
210 # the end after opening a file for appending is implementation
206 # the end after opening a file for appending is implementation
211 # defined. So we must seek to the end before calling tell(),
207 # defined. So we must seek to the end before calling tell(),
212 # or we may get a zero offset for non-zero sized files on
208 # or we may get a zero offset for non-zero sized files on
213 # some platforms (issue3543).
209 # some platforms (issue3543).
214 f.seek(0, 2) # os.SEEK_END
210 f.seek(0, 2) # os.SEEK_END
215 offset = f.tell()
211 offset = f.tell()
216 transaction.add('obsstore', offset)
212 transaction.add('obsstore', offset)
217 # offset == 0: new file - add the version header
213 # offset == 0: new file - add the version header
218 for bytes in _encodemarkers(new, offset == 0):
214 for bytes in _encodemarkers(new, offset == 0):
219 f.write(bytes)
215 f.write(bytes)
220 finally:
216 finally:
221 # XXX: f.close() == filecache invalidation == obsstore rebuilt.
217 # XXX: f.close() == filecache invalidation == obsstore rebuilt.
222 # call 'filecacheentry.refresh()' here
218 # call 'filecacheentry.refresh()' here
223 f.close()
219 f.close()
224 self._load(new)
220 self._load(new)
225 return len(new)
221 return len(new)
226
222
227 def mergemarkers(self, transation, data):
223 def mergemarkers(self, transation, data):
228 markers = _readmarkers(data)
224 markers = _readmarkers(data)
229 self.add(transation, markers)
225 self.add(transation, markers)
230
226
231 def _load(self, markers):
227 def _load(self, markers):
232 for mark in markers:
228 for mark in markers:
233 self._all.append(mark)
229 self._all.append(mark)
234 pre, sucs = mark[:2]
230 pre, sucs = mark[:2]
235 self.precursors.setdefault(pre, set()).add(mark)
231 self.precursors.setdefault(pre, set()).add(mark)
236 for suc in sucs:
232 for suc in sucs:
237 self.successors.setdefault(suc, set()).add(mark)
233 self.successors.setdefault(suc, set()).add(mark)
238
234
239 def _encodemarkers(markers, addheader=False):
235 def _encodemarkers(markers, addheader=False):
240 # Kept separate from flushmarkers(), it will be reused for
236 # Kept separate from flushmarkers(), it will be reused for
241 # markers exchange.
237 # markers exchange.
242 if addheader:
238 if addheader:
243 yield _pack('>B', _fmversion)
239 yield _pack('>B', _fmversion)
244 for marker in markers:
240 for marker in markers:
245 yield _encodeonemarker(marker)
241 yield _encodeonemarker(marker)
246
242
247
243
248 def _encodeonemarker(marker):
244 def _encodeonemarker(marker):
249 pre, sucs, flags, metadata = marker
245 pre, sucs, flags, metadata = marker
250 nbsuc = len(sucs)
246 nbsuc = len(sucs)
251 format = _fmfixed + (_fmnode * nbsuc)
247 format = _fmfixed + (_fmnode * nbsuc)
252 data = [nbsuc, len(metadata), flags, pre]
248 data = [nbsuc, len(metadata), flags, pre]
253 data.extend(sucs)
249 data.extend(sucs)
254 return _pack(format, *data) + metadata
250 return _pack(format, *data) + metadata
255
251
256 # arbitrary picked to fit into 8K limit from HTTP server
252 # arbitrary picked to fit into 8K limit from HTTP server
257 # you have to take in account:
253 # you have to take in account:
258 # - the version header
254 # - the version header
259 # - the base85 encoding
255 # - the base85 encoding
260 _maxpayload = 5300
256 _maxpayload = 5300
261
257
262 def listmarkers(repo):
258 def listmarkers(repo):
263 """List markers over pushkey"""
259 """List markers over pushkey"""
264 if not repo.obsstore:
260 if not repo.obsstore:
265 return {}
261 return {}
266 keys = {}
262 keys = {}
267 parts = []
263 parts = []
268 currentlen = _maxpayload * 2 # ensure we create a new part
264 currentlen = _maxpayload * 2 # ensure we create a new part
269 for marker in repo.obsstore:
265 for marker in repo.obsstore:
270 nextdata = _encodeonemarker(marker)
266 nextdata = _encodeonemarker(marker)
271 if (len(nextdata) + currentlen > _maxpayload):
267 if (len(nextdata) + currentlen > _maxpayload):
272 currentpart = []
268 currentpart = []
273 currentlen = 0
269 currentlen = 0
274 parts.append(currentpart)
270 parts.append(currentpart)
275 currentpart.append(nextdata)
271 currentpart.append(nextdata)
276 currentlen += len(nextdata)
272 currentlen += len(nextdata)
277 for idx, part in enumerate(reversed(parts)):
273 for idx, part in enumerate(reversed(parts)):
278 data = ''.join([_pack('>B', _fmversion)] + part)
274 data = ''.join([_pack('>B', _fmversion)] + part)
279 keys['dump%i' % idx] = base85.b85encode(data)
275 keys['dump%i' % idx] = base85.b85encode(data)
280 return keys
276 return keys
281
277
282 def pushmarker(repo, key, old, new):
278 def pushmarker(repo, key, old, new):
283 """Push markers over pushkey"""
279 """Push markers over pushkey"""
284 if not key.startswith('dump'):
280 if not key.startswith('dump'):
285 repo.ui.warn(_('unknown key: %r') % key)
281 repo.ui.warn(_('unknown key: %r') % key)
286 return 0
282 return 0
287 if old:
283 if old:
288 repo.ui.warn(_('unexpected old value') % key)
284 repo.ui.warn(_('unexpected old value') % key)
289 return 0
285 return 0
290 data = base85.b85decode(new)
286 data = base85.b85decode(new)
291 lock = repo.lock()
287 lock = repo.lock()
292 try:
288 try:
293 tr = repo.transaction('pushkey: obsolete markers')
289 tr = repo.transaction('pushkey: obsolete markers')
294 try:
290 try:
295 repo.obsstore.mergemarkers(tr, data)
291 repo.obsstore.mergemarkers(tr, data)
296 tr.close()
292 tr.close()
297 return 1
293 return 1
298 finally:
294 finally:
299 tr.release()
295 tr.release()
300 finally:
296 finally:
301 lock.release()
297 lock.release()
302
298
303 def allmarkers(repo):
299 def allmarkers(repo):
304 """all obsolete markers known in a repository"""
300 """all obsolete markers known in a repository"""
305 for markerdata in repo.obsstore:
301 for markerdata in repo.obsstore:
306 yield marker(repo, markerdata)
302 yield marker(repo, markerdata)
307
303
308 def precursormarkers(ctx):
304 def precursormarkers(ctx):
309 """obsolete marker making this changeset obsolete"""
305 """obsolete marker making this changeset obsolete"""
310 for data in ctx._repo.obsstore.precursors.get(ctx.node(), ()):
306 for data in ctx._repo.obsstore.precursors.get(ctx.node(), ()):
311 yield marker(ctx._repo, data)
307 yield marker(ctx._repo, data)
312
308
313 def successormarkers(ctx):
309 def successormarkers(ctx):
314 """obsolete marker marking this changeset as a successors"""
310 """obsolete marker marking this changeset as a successors"""
315 for data in ctx._repo.obsstore.successors.get(ctx.node(), ()):
311 for data in ctx._repo.obsstore.successors.get(ctx.node(), ()):
316 yield marker(ctx._repo, data)
312 yield marker(ctx._repo, data)
317
313
318 def anysuccessors(obsstore, node):
314 def anysuccessors(obsstore, node):
319 """Yield every successor of <node>
315 """Yield every successor of <node>
320
316
321 This is a linear yield unsuitable to detect split changesets."""
317 This is a linear yield unsuitable to detect split changesets."""
322 remaining = set([node])
318 remaining = set([node])
323 seen = set(remaining)
319 seen = set(remaining)
324 while remaining:
320 while remaining:
325 current = remaining.pop()
321 current = remaining.pop()
326 yield current
322 yield current
327 for mark in obsstore.precursors.get(current, ()):
323 for mark in obsstore.precursors.get(current, ()):
328 for suc in mark[1]:
324 for suc in mark[1]:
329 if suc not in seen:
325 if suc not in seen:
330 seen.add(suc)
326 seen.add(suc)
331 remaining.add(suc)
327 remaining.add(suc)
@@ -1,200 +1,197
1 # setdiscovery.py - improved discovery of common nodeset for mercurial
1 # setdiscovery.py - improved discovery of common nodeset for mercurial
2 #
2 #
3 # Copyright 2010 Benoit Boissinot <bboissin@gmail.com>
3 # Copyright 2010 Benoit Boissinot <bboissin@gmail.com>
4 # and Peter Arrenbrecht <peter@arrenbrecht.ch>
4 # and Peter Arrenbrecht <peter@arrenbrecht.ch>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from node import nullid
9 from node import nullid
10 from i18n import _
10 from i18n import _
11 import random, util, dagutil
11 import random, util, dagutil
12
12
13 def _updatesample(dag, nodes, sample, always, quicksamplesize=0):
13 def _updatesample(dag, nodes, sample, always, quicksamplesize=0):
14 # if nodes is empty we scan the entire graph
14 # if nodes is empty we scan the entire graph
15 if nodes:
15 if nodes:
16 heads = dag.headsetofconnecteds(nodes)
16 heads = dag.headsetofconnecteds(nodes)
17 else:
17 else:
18 heads = dag.heads()
18 heads = dag.heads()
19 dist = {}
19 dist = {}
20 visit = util.deque(heads)
20 visit = util.deque(heads)
21 seen = set()
21 seen = set()
22 factor = 1
22 factor = 1
23 while visit:
23 while visit:
24 curr = visit.popleft()
24 curr = visit.popleft()
25 if curr in seen:
25 if curr in seen:
26 continue
26 continue
27 d = dist.setdefault(curr, 1)
27 d = dist.setdefault(curr, 1)
28 if d > factor:
28 if d > factor:
29 factor *= 2
29 factor *= 2
30 if d == factor:
30 if d == factor:
31 if curr not in always: # need this check for the early exit below
31 if curr not in always: # need this check for the early exit below
32 sample.add(curr)
32 sample.add(curr)
33 if quicksamplesize and (len(sample) >= quicksamplesize):
33 if quicksamplesize and (len(sample) >= quicksamplesize):
34 return
34 return
35 seen.add(curr)
35 seen.add(curr)
36 for p in dag.parents(curr):
36 for p in dag.parents(curr):
37 if not nodes or p in nodes:
37 if not nodes or p in nodes:
38 dist.setdefault(p, d + 1)
38 dist.setdefault(p, d + 1)
39 visit.append(p)
39 visit.append(p)
40
40
41 def _setupsample(dag, nodes, size):
41 def _setupsample(dag, nodes, size):
42 if len(nodes) <= size:
42 if len(nodes) <= size:
43 return set(nodes), None, 0
43 return set(nodes), None, 0
44 always = dag.headsetofconnecteds(nodes)
44 always = dag.headsetofconnecteds(nodes)
45 desiredlen = size - len(always)
45 desiredlen = size - len(always)
46 if desiredlen <= 0:
46 if desiredlen <= 0:
47 # This could be bad if there are very many heads, all unknown to the
47 # This could be bad if there are very many heads, all unknown to the
48 # server. We're counting on long request support here.
48 # server. We're counting on long request support here.
49 return always, None, desiredlen
49 return always, None, desiredlen
50 return always, set(), desiredlen
50 return always, set(), desiredlen
51
51
52 def _takequicksample(dag, nodes, size, initial):
52 def _takequicksample(dag, nodes, size, initial):
53 always, sample, desiredlen = _setupsample(dag, nodes, size)
53 always, sample, desiredlen = _setupsample(dag, nodes, size)
54 if sample is None:
54 if sample is None:
55 return always
55 return always
56 if initial:
56 if initial:
57 fromset = None
57 fromset = None
58 else:
58 else:
59 fromset = nodes
59 fromset = nodes
60 _updatesample(dag, fromset, sample, always, quicksamplesize=desiredlen)
60 _updatesample(dag, fromset, sample, always, quicksamplesize=desiredlen)
61 sample.update(always)
61 sample.update(always)
62 return sample
62 return sample
63
63
64 def _takefullsample(dag, nodes, size):
64 def _takefullsample(dag, nodes, size):
65 always, sample, desiredlen = _setupsample(dag, nodes, size)
65 always, sample, desiredlen = _setupsample(dag, nodes, size)
66 if sample is None:
66 if sample is None:
67 return always
67 return always
68 # update from heads
68 # update from heads
69 _updatesample(dag, nodes, sample, always)
69 _updatesample(dag, nodes, sample, always)
70 # update from roots
70 # update from roots
71 _updatesample(dag.inverse(), nodes, sample, always)
71 _updatesample(dag.inverse(), nodes, sample, always)
72 assert sample
72 assert sample
73 if len(sample) > desiredlen:
73 if len(sample) > desiredlen:
74 sample = set(random.sample(sample, desiredlen))
74 sample = set(random.sample(sample, desiredlen))
75 elif len(sample) < desiredlen:
75 elif len(sample) < desiredlen:
76 more = desiredlen - len(sample)
76 more = desiredlen - len(sample)
77 sample.update(random.sample(list(nodes - sample - always), more))
77 sample.update(random.sample(list(nodes - sample - always), more))
78 sample.update(always)
78 sample.update(always)
79 return sample
79 return sample
80
80
81 def findcommonheads(ui, local, remote,
81 def findcommonheads(ui, local, remote,
82 initialsamplesize=100,
82 initialsamplesize=100,
83 fullsamplesize=200,
83 fullsamplesize=200,
84 abortwhenunrelated=True):
84 abortwhenunrelated=True):
85 '''Return a tuple (common, anyincoming, remoteheads) used to identify
85 '''Return a tuple (common, anyincoming, remoteheads) used to identify
86 missing nodes from or in remote.
86 missing nodes from or in remote.
87
88 shortcutlocal determines whether we try use direct access to localrepo if
89 remote is actually local.
90 '''
87 '''
91 roundtrips = 0
88 roundtrips = 0
92 cl = local.changelog
89 cl = local.changelog
93 dag = dagutil.revlogdag(cl)
90 dag = dagutil.revlogdag(cl)
94
91
95 # early exit if we know all the specified remote heads already
92 # early exit if we know all the specified remote heads already
96 ui.debug("query 1; heads\n")
93 ui.debug("query 1; heads\n")
97 roundtrips += 1
94 roundtrips += 1
98 ownheads = dag.heads()
95 ownheads = dag.heads()
99 sample = ownheads
96 sample = ownheads
100 if remote.local():
97 if remote.local():
101 # stopgap until we have a proper localpeer that supports batch()
98 # stopgap until we have a proper localpeer that supports batch()
102 srvheadhashes = remote.heads()
99 srvheadhashes = remote.heads()
103 yesno = remote.known(dag.externalizeall(sample))
100 yesno = remote.known(dag.externalizeall(sample))
104 elif remote.capable('batch'):
101 elif remote.capable('batch'):
105 batch = remote.batch()
102 batch = remote.batch()
106 srvheadhashesref = batch.heads()
103 srvheadhashesref = batch.heads()
107 yesnoref = batch.known(dag.externalizeall(sample))
104 yesnoref = batch.known(dag.externalizeall(sample))
108 batch.submit()
105 batch.submit()
109 srvheadhashes = srvheadhashesref.value
106 srvheadhashes = srvheadhashesref.value
110 yesno = yesnoref.value
107 yesno = yesnoref.value
111 else:
108 else:
112 # compatibility with pre-batch, but post-known remotes during 1.9
109 # compatibility with pre-batch, but post-known remotes during 1.9
113 # development
110 # development
114 srvheadhashes = remote.heads()
111 srvheadhashes = remote.heads()
115 sample = []
112 sample = []
116
113
117 if cl.tip() == nullid:
114 if cl.tip() == nullid:
118 if srvheadhashes != [nullid]:
115 if srvheadhashes != [nullid]:
119 return [nullid], True, srvheadhashes
116 return [nullid], True, srvheadhashes
120 return [nullid], False, []
117 return [nullid], False, []
121
118
122 # start actual discovery (we note this before the next "if" for
119 # start actual discovery (we note this before the next "if" for
123 # compatibility reasons)
120 # compatibility reasons)
124 ui.status(_("searching for changes\n"))
121 ui.status(_("searching for changes\n"))
125
122
126 srvheads = dag.internalizeall(srvheadhashes, filterunknown=True)
123 srvheads = dag.internalizeall(srvheadhashes, filterunknown=True)
127 if len(srvheads) == len(srvheadhashes):
124 if len(srvheads) == len(srvheadhashes):
128 ui.debug("all remote heads known locally\n")
125 ui.debug("all remote heads known locally\n")
129 return (srvheadhashes, False, srvheadhashes,)
126 return (srvheadhashes, False, srvheadhashes,)
130
127
131 if sample and util.all(yesno):
128 if sample and util.all(yesno):
132 ui.note(_("all local heads known remotely\n"))
129 ui.note(_("all local heads known remotely\n"))
133 ownheadhashes = dag.externalizeall(ownheads)
130 ownheadhashes = dag.externalizeall(ownheads)
134 return (ownheadhashes, True, srvheadhashes,)
131 return (ownheadhashes, True, srvheadhashes,)
135
132
136 # full blown discovery
133 # full blown discovery
137
134
138 # own nodes where I don't know if remote knows them
135 # own nodes where I don't know if remote knows them
139 undecided = dag.nodeset()
136 undecided = dag.nodeset()
140 # own nodes I know we both know
137 # own nodes I know we both know
141 common = set()
138 common = set()
142 # own nodes I know remote lacks
139 # own nodes I know remote lacks
143 missing = set()
140 missing = set()
144
141
145 # treat remote heads (and maybe own heads) as a first implicit sample
142 # treat remote heads (and maybe own heads) as a first implicit sample
146 # response
143 # response
147 common.update(dag.ancestorset(srvheads))
144 common.update(dag.ancestorset(srvheads))
148 undecided.difference_update(common)
145 undecided.difference_update(common)
149
146
150 full = False
147 full = False
151 while undecided:
148 while undecided:
152
149
153 if sample:
150 if sample:
154 commoninsample = set(n for i, n in enumerate(sample) if yesno[i])
151 commoninsample = set(n for i, n in enumerate(sample) if yesno[i])
155 common.update(dag.ancestorset(commoninsample, common))
152 common.update(dag.ancestorset(commoninsample, common))
156
153
157 missinginsample = [n for i, n in enumerate(sample) if not yesno[i]]
154 missinginsample = [n for i, n in enumerate(sample) if not yesno[i]]
158 missing.update(dag.descendantset(missinginsample, missing))
155 missing.update(dag.descendantset(missinginsample, missing))
159
156
160 undecided.difference_update(missing)
157 undecided.difference_update(missing)
161 undecided.difference_update(common)
158 undecided.difference_update(common)
162
159
163 if not undecided:
160 if not undecided:
164 break
161 break
165
162
166 if full:
163 if full:
167 ui.note(_("sampling from both directions\n"))
164 ui.note(_("sampling from both directions\n"))
168 sample = _takefullsample(dag, undecided, size=fullsamplesize)
165 sample = _takefullsample(dag, undecided, size=fullsamplesize)
169 elif common:
166 elif common:
170 # use cheapish initial sample
167 # use cheapish initial sample
171 ui.debug("taking initial sample\n")
168 ui.debug("taking initial sample\n")
172 sample = _takefullsample(dag, undecided, size=fullsamplesize)
169 sample = _takefullsample(dag, undecided, size=fullsamplesize)
173 else:
170 else:
174 # use even cheaper initial sample
171 # use even cheaper initial sample
175 ui.debug("taking quick initial sample\n")
172 ui.debug("taking quick initial sample\n")
176 sample = _takequicksample(dag, undecided, size=initialsamplesize,
173 sample = _takequicksample(dag, undecided, size=initialsamplesize,
177 initial=True)
174 initial=True)
178
175
179 roundtrips += 1
176 roundtrips += 1
180 ui.progress(_('searching'), roundtrips, unit=_('queries'))
177 ui.progress(_('searching'), roundtrips, unit=_('queries'))
181 ui.debug("query %i; still undecided: %i, sample size is: %i\n"
178 ui.debug("query %i; still undecided: %i, sample size is: %i\n"
182 % (roundtrips, len(undecided), len(sample)))
179 % (roundtrips, len(undecided), len(sample)))
183 # indices between sample and externalized version must match
180 # indices between sample and externalized version must match
184 sample = list(sample)
181 sample = list(sample)
185 yesno = remote.known(dag.externalizeall(sample))
182 yesno = remote.known(dag.externalizeall(sample))
186 full = True
183 full = True
187
184
188 result = dag.headsetofconnecteds(common)
185 result = dag.headsetofconnecteds(common)
189 ui.progress(_('searching'), None)
186 ui.progress(_('searching'), None)
190 ui.debug("%d total queries\n" % roundtrips)
187 ui.debug("%d total queries\n" % roundtrips)
191
188
192 if not result and srvheadhashes != [nullid]:
189 if not result and srvheadhashes != [nullid]:
193 if abortwhenunrelated:
190 if abortwhenunrelated:
194 raise util.Abort(_("repository is unrelated"))
191 raise util.Abort(_("repository is unrelated"))
195 else:
192 else:
196 ui.warn(_("warning: repository is unrelated\n"))
193 ui.warn(_("warning: repository is unrelated\n"))
197 return (set([nullid]), True, srvheadhashes,)
194 return (set([nullid]), True, srvheadhashes,)
198
195
199 anyincoming = (srvheadhashes != [nullid])
196 anyincoming = (srvheadhashes != [nullid])
200 return dag.externalizeall(result), anyincoming, srvheadhashes
197 return dag.externalizeall(result), anyincoming, srvheadhashes
General Comments 0
You need to be logged in to leave comments. Login now