delete some dead comments and docstrings
Mads Kiilerich - r17426:9724f8f8 default
@@ -1,360 +1,359
1 1 # monotone.py - monotone support for the convert extension
2 2 #
3 3 # Copyright 2008, 2009 Mikkel Fahnoe Jorgensen <mikkel@dvide.com> and
4 4 # others
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 import os, re
10 10 from mercurial import util
11 11 from common import NoRepo, commit, converter_source, checktool
12 12 from common import commandline
13 13 from mercurial.i18n import _
14 14
15 15 class monotone_source(converter_source, commandline):
16 16 def __init__(self, ui, path=None, rev=None):
17 17 converter_source.__init__(self, ui, path, rev)
18 18 commandline.__init__(self, ui, 'mtn')
19 19
20 20 self.ui = ui
21 21 self.path = path
22 22 self.automatestdio = False
23 23 self.rev = rev
24 24
25 25 norepo = NoRepo(_("%s does not look like a monotone repository")
26 26 % path)
27 27 if not os.path.exists(os.path.join(path, '_MTN')):
28 28 # Could be a monotone repository (SQLite db file)
29 29 try:
30 30 f = file(path, 'rb')
31 31 header = f.read(16)
32 32 f.close()
33 33 except IOError:
34 34 header = ''
35 35 if header != 'SQLite format 3\x00':
36 36 raise norepo
37 37
38 38 # regular expressions for parsing monotone output
39 39 space = r'\s*'
40 40 name = r'\s+"((?:\\"|[^"])*)"\s*'
41 41 value = name
42 42 revision = r'\s+\[(\w+)\]\s*'
43 43 lines = r'(?:.|\n)+'
44 44
45 45 self.dir_re = re.compile(space + "dir" + name)
46 46 self.file_re = re.compile(space + "file" + name +
47 47 "content" + revision)
48 48 self.add_file_re = re.compile(space + "add_file" + name +
49 49 "content" + revision)
50 50 self.patch_re = re.compile(space + "patch" + name +
51 51 "from" + revision + "to" + revision)
52 52 self.rename_re = re.compile(space + "rename" + name + "to" + name)
53 53 self.delete_re = re.compile(space + "delete" + name)
54 54 self.tag_re = re.compile(space + "tag" + name + "revision" +
55 55 revision)
56 56 self.cert_re = re.compile(lines + space + "name" + name +
57 57 "value" + value)
58 58
59 59 attr = space + "file" + lines + space + "attr" + space
60 60 self.attr_execute_re = re.compile(attr + '"mtn:execute"' +
61 61 space + '"true"')
62 62
63 63 # cached data
64 64 self.manifest_rev = None
65 65 self.manifest = None
66 66 self.files = None
67 67 self.dirs = None
68 68
69 69 checktool('mtn', abort=False)
70 70
71 71 def mtnrun(self, *args, **kwargs):
72 72 if self.automatestdio:
73 73 return self.mtnrunstdio(*args, **kwargs)
74 74 else:
75 75 return self.mtnrunsingle(*args, **kwargs)
76 76
77 77 def mtnrunsingle(self, *args, **kwargs):
78 78 kwargs['d'] = self.path
79 79 return self.run0('automate', *args, **kwargs)
80 80
81 81 def mtnrunstdio(self, *args, **kwargs):
82 82 # Prepare the command in automate stdio format
83 83 command = []
84 84 for k, v in kwargs.iteritems():
85 85 command.append("%s:%s" % (len(k), k))
86 86 if v:
87 87 command.append("%s:%s" % (len(v), v))
88 88 if command:
89 89 command.insert(0, 'o')
90 90 command.append('e')
91 91
92 92 command.append('l')
93 93 for arg in args:
94 94 command += "%s:%s" % (len(arg), arg)
95 95 command.append('e')
96 96 command = ''.join(command)
97 97
98 98 self.ui.debug("mtn: sending '%s'\n" % command)
99 99 self.mtnwritefp.write(command)
100 100 self.mtnwritefp.flush()
101 101
102 102 return self.mtnstdioreadcommandoutput(command)
103 103
104 104 def mtnstdioreadpacket(self):
105 105 read = None
106 106 commandnbr = ''
107 107 while read != ':':
108 108 read = self.mtnreadfp.read(1)
109 109 if not read:
110 110 raise util.Abort(_('bad mtn packet - no end of commandnbr'))
111 111 commandnbr += read
112 112 commandnbr = commandnbr[:-1]
113 113
114 114 stream = self.mtnreadfp.read(1)
115 115 if stream not in 'mewptl':
116 116 raise util.Abort(_('bad mtn packet - bad stream type %s') % stream)
117 117
118 118 read = self.mtnreadfp.read(1)
119 119 if read != ':':
120 120 raise util.Abort(_('bad mtn packet - no divider before size'))
121 121
122 122 read = None
123 123 lengthstr = ''
124 124 while read != ':':
125 125 read = self.mtnreadfp.read(1)
126 126 if not read:
127 127 raise util.Abort(_('bad mtn packet - no end of packet size'))
128 128 lengthstr += read
129 129 try:
130 130 length = long(lengthstr[:-1])
131 131 except TypeError:
132 132 raise util.Abort(_('bad mtn packet - bad packet size %s')
133 133 % lengthstr)
134 134
135 135 read = self.mtnreadfp.read(length)
136 136 if len(read) != length:
137 137 raise util.Abort(_("bad mtn packet - unable to read full packet "
138 138 "read %s of %s") % (len(read), length))
139 139
140 140 return (commandnbr, stream, length, read)
141 141
142 142 def mtnstdioreadcommandoutput(self, command):
143 143 retval = []
144 144 while True:
145 145 commandnbr, stream, length, output = self.mtnstdioreadpacket()
146 146 self.ui.debug('mtn: read packet %s:%s:%s\n' %
147 147 (commandnbr, stream, length))
148 148
149 149 if stream == 'l':
150 150 # End of command
151 151 if output != '0':
152 152 raise util.Abort(_("mtn command '%s' returned %s") %
153 153 (command, output))
154 154 break
155 155 elif stream in 'ew':
156 156 # Error, warning output
157 157 self.ui.warn(_('%s error:\n') % self.command)
158 158 self.ui.warn(output)
159 159 elif stream == 'p':
160 160 # Progress messages
161 161 self.ui.debug('mtn: ' + output)
162 162 elif stream == 'm':
163 163 # Main stream - command output
164 164 retval.append(output)
165 165
166 166 return ''.join(retval)
167 167
168 168 def mtnloadmanifest(self, rev):
169 169 if self.manifest_rev == rev:
170 170 return
171 171 self.manifest = self.mtnrun("get_manifest_of", rev).split("\n\n")
172 172 self.manifest_rev = rev
173 173 self.files = {}
174 174 self.dirs = {}
175 175
176 176 for e in self.manifest:
177 177 m = self.file_re.match(e)
178 178 if m:
179 179 attr = ""
180 180 name = m.group(1)
181 181 node = m.group(2)
182 182 if self.attr_execute_re.match(e):
183 183 attr += "x"
184 184 self.files[name] = (node, attr)
185 185 m = self.dir_re.match(e)
186 186 if m:
187 187 self.dirs[m.group(1)] = True
188 188
189 189 def mtnisfile(self, name, rev):
190 190 # a non-file could be a directory or a deleted or renamed file
191 191 self.mtnloadmanifest(rev)
192 192 return name in self.files
193 193
194 194 def mtnisdir(self, name, rev):
195 195 self.mtnloadmanifest(rev)
196 196 return name in self.dirs
197 197
198 198 def mtngetcerts(self, rev):
199 199 certs = {"author":"<missing>", "date":"<missing>",
200 200 "changelog":"<missing>", "branch":"<missing>"}
201 201 certlist = self.mtnrun("certs", rev)
202 202 # mtn < 0.45:
203 203 # key "test@selenic.com"
204 204 # mtn >= 0.45:
205 205 # key [ff58a7ffb771907c4ff68995eada1c4da068d328]
206 206 certlist = re.split('\n\n key ["\[]', certlist)
207 207 for e in certlist:
208 208 m = self.cert_re.match(e)
209 209 if m:
210 210 name, value = m.groups()
211 211 value = value.replace(r'\"', '"')
212 212 value = value.replace(r'\\', '\\')
213 213 certs[name] = value
214 214 # Monotone may have subsecond dates: 2005-02-05T09:39:12.364306
215 215 # and all times are stored in UTC
216 216 certs["date"] = certs["date"].split('.')[0] + " UTC"
217 217 return certs
218 218
219 219 # implement the converter_source interface:
220 220
221 221 def getheads(self):
222 222 if not self.rev:
223 223 return self.mtnrun("leaves").splitlines()
224 224 else:
225 225 return [self.rev]
226 226
227 227 def getchanges(self, rev):
228 #revision = self.mtncmd("get_revision %s" % rev).split("\n\n")
229 228 revision = self.mtnrun("get_revision", rev).split("\n\n")
230 229 files = {}
231 230 ignoremove = {}
232 231 renameddirs = []
233 232 copies = {}
234 233 for e in revision:
235 234 m = self.add_file_re.match(e)
236 235 if m:
237 236 files[m.group(1)] = rev
238 237 ignoremove[m.group(1)] = rev
239 238 m = self.patch_re.match(e)
240 239 if m:
241 240 files[m.group(1)] = rev
242 241 # Delete/rename is handled later when the convert engine
243 242 # discovers an IOError exception from getfile,
244 243 # but only if we add the "from" file to the list of changes.
245 244 m = self.delete_re.match(e)
246 245 if m:
247 246 files[m.group(1)] = rev
248 247 m = self.rename_re.match(e)
249 248 if m:
250 249 toname = m.group(2)
251 250 fromname = m.group(1)
252 251 if self.mtnisfile(toname, rev):
253 252 ignoremove[toname] = 1
254 253 copies[toname] = fromname
255 254 files[toname] = rev
256 255 files[fromname] = rev
257 256 elif self.mtnisdir(toname, rev):
258 257 renameddirs.append((fromname, toname))
259 258
260 259 # Directory renames can be handled only once we have recorded
261 260 # all new files
262 261 for fromdir, todir in renameddirs:
263 262 renamed = {}
264 263 for tofile in self.files:
265 264 if tofile in ignoremove:
266 265 continue
267 266 if tofile.startswith(todir + '/'):
268 267 renamed[tofile] = fromdir + tofile[len(todir):]
269 268 # Avoid chained moves like:
270 269 # d1(/a) => d3/d1(/a)
271 270 # d2 => d3
272 271 ignoremove[tofile] = 1
273 272 for tofile, fromfile in renamed.items():
274 273 self.ui.debug (_("copying file in renamed directory "
275 274 "from '%s' to '%s'")
276 275 % (fromfile, tofile), '\n')
277 276 files[tofile] = rev
278 277 copies[tofile] = fromfile
279 278 for fromfile in renamed.values():
280 279 files[fromfile] = rev
281 280
282 281 return (files.items(), copies)
283 282
284 283 def getfile(self, name, rev):
285 284 if not self.mtnisfile(name, rev):
286 285 raise IOError # file was deleted or renamed
287 286 try:
288 287 data = self.mtnrun("get_file_of", name, r=rev)
289 288 except Exception:
290 289 raise IOError # file was deleted or renamed
291 290 self.mtnloadmanifest(rev)
292 291 node, attr = self.files.get(name, (None, ""))
293 292 return data, attr
294 293
295 294 def getcommit(self, rev):
296 295 extra = {}
297 296 certs = self.mtngetcerts(rev)
298 297 if certs.get('suspend') == certs["branch"]:
299 298 extra['close'] = '1'
300 299 return commit(
301 300 author=certs["author"],
302 301 date=util.datestr(util.strdate(certs["date"], "%Y-%m-%dT%H:%M:%S")),
303 302 desc=certs["changelog"],
304 303 rev=rev,
305 304 parents=self.mtnrun("parents", rev).splitlines(),
306 305 branch=certs["branch"],
307 306 extra=extra)
308 307
309 308 def gettags(self):
310 309 tags = {}
311 310 for e in self.mtnrun("tags").split("\n\n"):
312 311 m = self.tag_re.match(e)
313 312 if m:
314 313 tags[m.group(1)] = m.group(2)
315 314 return tags
316 315
317 316 def getchangedfiles(self, rev, i):
318 317 # This function is only needed to support --filemap
319 318 # ... and we don't support that
320 319 raise NotImplementedError
321 320
322 321 def before(self):
323 322 # Check if we have a new enough version to use automate stdio
324 323 version = 0.0
325 324 try:
326 325 versionstr = self.mtnrunsingle("interface_version")
327 326 version = float(versionstr)
328 327 except Exception:
329 328 raise util.Abort(_("unable to determine mtn automate interface "
330 329 "version"))
331 330
332 331 if version >= 12.0:
333 332 self.automatestdio = True
334 333 self.ui.debug("mtn automate version %s - using automate stdio\n" %
335 334 version)
336 335
337 336 # launch the long-running automate stdio process
338 337 self.mtnwritefp, self.mtnreadfp = self._run2('automate', 'stdio',
339 338 '-d', self.path)
340 339 # read the headers
341 340 read = self.mtnreadfp.readline()
342 341 if read != 'format-version: 2\n':
343 342 raise util.Abort(_('mtn automate stdio header unexpected: %s')
344 343 % read)
345 344 while read != '\n':
346 345 read = self.mtnreadfp.readline()
347 346 if not read:
348 347 raise util.Abort(_("failed to reach end of mtn automate "
349 348 "stdio headers"))
350 349 else:
351 350 self.ui.debug("mtn automate version %s - not using automate stdio "
352 351 "(automate >= 12.0 - mtn >= 0.46 is needed)\n" % version)
353 352
354 353 def after(self):
355 354 if self.automatestdio:
356 355 self.mtnwritefp.close()
357 356 self.mtnwritefp = None
358 357 self.mtnreadfp.close()
359 358 self.mtnreadfp = None
360 359
@@ -1,1582 +1,1581
1 1 """ Multicast DNS Service Discovery for Python, v0.12
2 2 Copyright (C) 2003, Paul Scott-Murphy
3 3
4 4 This module provides a framework for the use of DNS Service Discovery
5 5 using IP multicast. It has been tested against the JRendezvous
6 6 implementation from <a href="http://strangeberry.com">StrangeBerry</a>,
7 7 and against the mDNSResponder from Mac OS X 10.3.8.
8 8
9 9 This library is free software; you can redistribute it and/or
10 10 modify it under the terms of the GNU Lesser General Public
11 11 License as published by the Free Software Foundation; either
12 12 version 2.1 of the License, or (at your option) any later version.
13 13
14 14 This library is distributed in the hope that it will be useful,
15 15 but WITHOUT ANY WARRANTY; without even the implied warranty of
16 16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 17 Lesser General Public License for more details.
18 18
19 19 You should have received a copy of the GNU Lesser General Public
20 20 License along with this library; if not, see
21 21 <http://www.gnu.org/licenses/>.
22 22
23 23 """
24 24
25 25 """0.12 update - allow selection of binding interface
26 26 typo fix - Thanks A. M. Kuchlingi
27 27 removed all use of word 'Rendezvous' - this is an API change"""
28 28
29 29 """0.11 update - correction to comments for addListener method
30 30 support for new record types seen from OS X
31 31 - IPv6 address
32 32 - hostinfo
33 33 ignore unknown DNS record types
34 34 fixes to name decoding
35 35 works alongside other processes using port 5353 (e.g. on Mac OS X)
36 36 tested against Mac OS X 10.3.2's mDNSResponder
37 37 corrections to removal of list entries for service browser"""
38 38
39 39 """0.10 update - Jonathon Paisley contributed these corrections:
40 40 always multicast replies, even when query is unicast
41 41 correct a pointer encoding problem
42 42 can now write records in any order
43 43 traceback shown on failure
44 44 better TXT record parsing
45 45 server is now separate from name
46 46 can cancel a service browser
47 47
48 48 modified some unit tests to accommodate these changes"""
49 49
50 50 """0.09 update - remove all records on service unregistration
51 51 fix DOS security problem with readName"""
52 52
53 53 """0.08 update - changed licensing to LGPL"""
54 54
55 55 """0.07 update - faster shutdown on engine
56 56 pointer encoding of outgoing names
57 57 ServiceBrowser now works
58 58 new unit tests"""
59 59
60 60 """0.06 update - small improvements with unit tests
61 61 added defined exception types
62 62 new style objects
63 63 fixed hostname/interface problem
64 64 fixed socket timeout problem
65 65 fixed addServiceListener() typo bug
66 66 using select() for socket reads
67 67 tested on Debian unstable with Python 2.2.2"""
68 68
69 69 """0.05 update - ensure case insensitivity on domain names
70 70 support for unicast DNS queries"""
71 71
72 72 """0.04 update - added some unit tests
73 73 added __ne__ adjuncts where required
74 74 ensure names end in '.local.'
75 75 timeout on receiving socket for clean shutdown"""
76 76
77 77 __author__ = "Paul Scott-Murphy"
78 78 __email__ = "paul at scott dash murphy dot com"
79 79 __version__ = "0.12"
80 80
81 81 import string
82 82 import time
83 83 import struct
84 84 import socket
85 85 import threading
86 86 import select
87 87 import traceback
88 88
89 89 __all__ = ["Zeroconf", "ServiceInfo", "ServiceBrowser"]
90 90
91 91 # hook for threads
92 92
93 93 globals()['_GLOBAL_DONE'] = 0
94 94
95 95 # Some timing constants
96 96
97 97 _UNREGISTER_TIME = 125
98 98 _CHECK_TIME = 175
99 99 _REGISTER_TIME = 225
100 100 _LISTENER_TIME = 200
101 101 _BROWSER_TIME = 500
102 102
103 103 # Some DNS constants
104 104
105 105 _MDNS_ADDR = '224.0.0.251'
106 106 _MDNS_PORT = 5353;
107 107 _DNS_PORT = 53;
108 108 _DNS_TTL = 60 * 60; # one hour default TTL
109 109
110 110 _MAX_MSG_TYPICAL = 1460 # unused
111 111 _MAX_MSG_ABSOLUTE = 8972
112 112
113 113 _FLAGS_QR_MASK = 0x8000 # query response mask
114 114 _FLAGS_QR_QUERY = 0x0000 # query
115 115 _FLAGS_QR_RESPONSE = 0x8000 # response
116 116
117 117 _FLAGS_AA = 0x0400 # Authoritative answer
118 118 _FLAGS_TC = 0x0200 # Truncated
119 119 _FLAGS_RD = 0x0100 # Recursion desired
120 120 _FLAGS_RA = 0x8000 # Recursion available
121 121
122 122 _FLAGS_Z = 0x0040 # Zero
123 123 _FLAGS_AD = 0x0020 # Authentic data
124 124 _FLAGS_CD = 0x0010 # Checking disabled
125 125
126 126 _CLASS_IN = 1
127 127 _CLASS_CS = 2
128 128 _CLASS_CH = 3
129 129 _CLASS_HS = 4
130 130 _CLASS_NONE = 254
131 131 _CLASS_ANY = 255
132 132 _CLASS_MASK = 0x7FFF
133 133 _CLASS_UNIQUE = 0x8000
134 134
135 135 _TYPE_A = 1
136 136 _TYPE_NS = 2
137 137 _TYPE_MD = 3
138 138 _TYPE_MF = 4
139 139 _TYPE_CNAME = 5
140 140 _TYPE_SOA = 6
141 141 _TYPE_MB = 7
142 142 _TYPE_MG = 8
143 143 _TYPE_MR = 9
144 144 _TYPE_NULL = 10
145 145 _TYPE_WKS = 11
146 146 _TYPE_PTR = 12
147 147 _TYPE_HINFO = 13
148 148 _TYPE_MINFO = 14
149 149 _TYPE_MX = 15
150 150 _TYPE_TXT = 16
151 151 _TYPE_AAAA = 28
152 152 _TYPE_SRV = 33
153 153 _TYPE_ANY = 255
154 154
155 155 # Mapping constants to names
156 156
157 157 _CLASSES = { _CLASS_IN : "in",
158 158 _CLASS_CS : "cs",
159 159 _CLASS_CH : "ch",
160 160 _CLASS_HS : "hs",
161 161 _CLASS_NONE : "none",
162 162 _CLASS_ANY : "any" }
163 163
164 164 _TYPES = { _TYPE_A : "a",
165 165 _TYPE_NS : "ns",
166 166 _TYPE_MD : "md",
167 167 _TYPE_MF : "mf",
168 168 _TYPE_CNAME : "cname",
169 169 _TYPE_SOA : "soa",
170 170 _TYPE_MB : "mb",
171 171 _TYPE_MG : "mg",
172 172 _TYPE_MR : "mr",
173 173 _TYPE_NULL : "null",
174 174 _TYPE_WKS : "wks",
175 175 _TYPE_PTR : "ptr",
176 176 _TYPE_HINFO : "hinfo",
177 177 _TYPE_MINFO : "minfo",
178 178 _TYPE_MX : "mx",
179 179 _TYPE_TXT : "txt",
180 180 _TYPE_AAAA : "quada",
181 181 _TYPE_SRV : "srv",
182 182 _TYPE_ANY : "any" }
183 183
184 184 # utility functions
185 185
186 186 def currentTimeMillis():
187 187 """Current system time in milliseconds"""
188 188 return time.time() * 1000
189 189
190 190 # Exceptions
191 191
192 192 class NonLocalNameException(Exception):
193 193 pass
194 194
195 195 class NonUniqueNameException(Exception):
196 196 pass
197 197
198 198 class NamePartTooLongException(Exception):
199 199 pass
200 200
201 201 class AbstractMethodException(Exception):
202 202 pass
203 203
204 204 class BadTypeInNameException(Exception):
205 205 pass
206 206
207 207 class BadDomainName(Exception):
208 208 def __init__(self, pos):
209 209 Exception.__init__(self, "at position %s" % pos)
210 210
211 211 class BadDomainNameCircular(BadDomainName):
212 212 pass
213 213
214 214 # implementation classes
215 215
216 216 class DNSEntry(object):
217 217 """A DNS entry"""
218 218
219 219 def __init__(self, name, type, clazz):
220 220 self.key = string.lower(name)
221 221 self.name = name
222 222 self.type = type
223 223 self.clazz = clazz & _CLASS_MASK
224 224 self.unique = (clazz & _CLASS_UNIQUE) != 0
225 225
226 226 def __eq__(self, other):
227 227 """Equality test on name, type, and class"""
228 228 if isinstance(other, DNSEntry):
229 229 return self.name == other.name and self.type == other.type and self.clazz == other.clazz
230 230 return 0
231 231
232 232 def __ne__(self, other):
233 233 """Non-equality test"""
234 234 return not self.__eq__(other)
235 235
236 236 def getClazz(self, clazz):
237 237 """Class accessor"""
238 238 try:
239 239 return _CLASSES[clazz]
240 240 except KeyError:
241 241 return "?(%s)" % (clazz)
242 242
243 243 def getType(self, type):
244 244 """Type accessor"""
245 245 try:
246 246 return _TYPES[type]
247 247 except KeyError:
248 248 return "?(%s)" % (type)
249 249
250 250 def toString(self, hdr, other):
251 251 """String representation with additional information"""
252 252 result = "%s[%s,%s" % (hdr, self.getType(self.type), self.getClazz(self.clazz))
253 253 if self.unique:
254 254 result += "-unique,"
255 255 else:
256 256 result += ","
257 257 result += self.name
258 258 if other is not None:
259 259 result += ",%s]" % (other)
260 260 else:
261 261 result += "]"
262 262 return result
263 263
264 264 class DNSQuestion(DNSEntry):
265 265 """A DNS question entry"""
266 266
267 267 def __init__(self, name, type, clazz):
268 268 if not name.endswith(".local."):
269 269 raise NonLocalNameException(name)
270 270 DNSEntry.__init__(self, name, type, clazz)
271 271
272 272 def answeredBy(self, rec):
273 273 """Returns true if the question is answered by the record"""
274 274 return self.clazz == rec.clazz and (self.type == rec.type or self.type == _TYPE_ANY) and self.name == rec.name
275 275
276 276 def __repr__(self):
277 277 """String representation"""
278 278 return DNSEntry.toString(self, "question", None)
279 279
280 280
281 281 class DNSRecord(DNSEntry):
282 282 """A DNS record - like a DNS entry, but has a TTL"""
283 283
284 284 def __init__(self, name, type, clazz, ttl):
285 285 DNSEntry.__init__(self, name, type, clazz)
286 286 self.ttl = ttl
287 287 self.created = currentTimeMillis()
288 288
289 289 def __eq__(self, other):
290 290 """Tests equality as per DNSRecord"""
291 291 if isinstance(other, DNSRecord):
292 292 return DNSEntry.__eq__(self, other)
293 293 return 0
294 294
295 295 def suppressedBy(self, msg):
296 296 """Returns true if any answer in a message can suffice for the
297 297 information held in this record."""
298 298 for record in msg.answers:
299 299 if self.suppressedByAnswer(record):
300 300 return 1
301 301 return 0
302 302
303 303 def suppressedByAnswer(self, other):
304 304 """Returns true if another record has same name, type and class,
305 305 and if its TTL is at least half of this record's."""
306 306 if self == other and other.ttl > (self.ttl / 2):
307 307 return 1
308 308 return 0
309 309
310 310 def getExpirationTime(self, percent):
311 311 """Returns the time at which this record will have expired
312 312 by a certain percentage."""
313 313 return self.created + (percent * self.ttl * 10)
314 314
315 315 def getRemainingTTL(self, now):
316 316 """Returns the remaining TTL in seconds."""
317 317 return max(0, (self.getExpirationTime(100) - now) / 1000)
318 318
319 319 def isExpired(self, now):
320 320 """Returns true if this record has expired."""
321 321 return self.getExpirationTime(100) <= now
322 322
323 323 def isStale(self, now):
324 324 """Returns true if this record is at least half way expired."""
325 325 return self.getExpirationTime(50) <= now
326 326
327 327 def resetTTL(self, other):
328 328 """Sets this record's TTL and created time to that of
329 329 another record."""
330 330 self.created = other.created
331 331 self.ttl = other.ttl
332 332
333 333 def write(self, out):
334 334 """Abstract method"""
335 335 raise AbstractMethodException
336 336
337 337 def toString(self, other):
338 338 """String representation with additional information"""
339 339 arg = "%s/%s,%s" % (self.ttl, self.getRemainingTTL(currentTimeMillis()), other)
340 340 return DNSEntry.toString(self, "record", arg)
341 341
342 342 class DNSAddress(DNSRecord):
343 343 """A DNS address record"""
344 344
345 345 def __init__(self, name, type, clazz, ttl, address):
346 346 DNSRecord.__init__(self, name, type, clazz, ttl)
347 347 self.address = address
348 348
349 349 def write(self, out):
350 350 """Used in constructing an outgoing packet"""
351 351 out.writeString(self.address, len(self.address))
352 352
353 353 def __eq__(self, other):
354 354 """Tests equality on address"""
355 355 if isinstance(other, DNSAddress):
356 356 return self.address == other.address
357 357 return 0
358 358
359 359 def __repr__(self):
360 360 """String representation"""
361 361 try:
362 362 return socket.inet_ntoa(self.address)
363 363 except Exception:
364 364 return self.address
365 365
366 366 class DNSHinfo(DNSRecord):
367 367 """A DNS host information record"""
368 368
369 369 def __init__(self, name, type, clazz, ttl, cpu, os):
370 370 DNSRecord.__init__(self, name, type, clazz, ttl)
371 371 self.cpu = cpu
372 372 self.os = os
373 373
374 374 def write(self, out):
375 375 """Used in constructing an outgoing packet"""
376 376 out.writeString(self.cpu, len(self.cpu))
377 377 out.writeString(self.os, len(self.os))
378 378
379 379 def __eq__(self, other):
380 380 """Tests equality on cpu and os"""
381 381 if isinstance(other, DNSHinfo):
382 382 return self.cpu == other.cpu and self.os == other.os
383 383 return 0
384 384
385 385 def __repr__(self):
386 386 """String representation"""
387 387 return self.cpu + " " + self.os
388 388
389 389 class DNSPointer(DNSRecord):
390 390 """A DNS pointer record"""
391 391
392 392 def __init__(self, name, type, clazz, ttl, alias):
393 393 DNSRecord.__init__(self, name, type, clazz, ttl)
394 394 self.alias = alias
395 395
396 396 def write(self, out):
397 397 """Used in constructing an outgoing packet"""
398 398 out.writeName(self.alias)
399 399
400 400 def __eq__(self, other):
401 401 """Tests equality on alias"""
402 402 if isinstance(other, DNSPointer):
403 403 return self.alias == other.alias
404 404 return 0
405 405
406 406 def __repr__(self):
407 407 """String representation"""
408 408 return self.toString(self.alias)
409 409
410 410 class DNSText(DNSRecord):
411 411 """A DNS text record"""
412 412
413 413 def __init__(self, name, type, clazz, ttl, text):
414 414 DNSRecord.__init__(self, name, type, clazz, ttl)
415 415 self.text = text
416 416
417 417 def write(self, out):
418 418 """Used in constructing an outgoing packet"""
419 419 out.writeString(self.text, len(self.text))
420 420
421 421 def __eq__(self, other):
422 422 """Tests equality on text"""
423 423 if isinstance(other, DNSText):
424 424 return self.text == other.text
425 425 return 0
426 426
427 427 def __repr__(self):
428 428 """String representation"""
429 429 if len(self.text) > 10:
430 430 return self.toString(self.text[:7] + "...")
431 431 else:
432 432 return self.toString(self.text)
433 433
434 434 class DNSService(DNSRecord):
435 435 """A DNS service record"""
436 436
437 437 def __init__(self, name, type, clazz, ttl, priority, weight, port, server):
438 438 DNSRecord.__init__(self, name, type, clazz, ttl)
439 439 self.priority = priority
440 440 self.weight = weight
441 441 self.port = port
442 442 self.server = server
443 443
444 444 def write(self, out):
445 445 """Used in constructing an outgoing packet"""
446 446 out.writeShort(self.priority)
447 447 out.writeShort(self.weight)
448 448 out.writeShort(self.port)
449 449 out.writeName(self.server)
450 450
451 451 def __eq__(self, other):
452 452 """Tests equality on priority, weight, port and server"""
453 453 if isinstance(other, DNSService):
454 454 return self.priority == other.priority and self.weight == other.weight and self.port == other.port and self.server == other.server
455 455 return 0
456 456
457 457 def __repr__(self):
458 458 """String representation"""
459 459 return self.toString("%s:%s" % (self.server, self.port))
460 460
461 461 class DNSIncoming(object):
462 462 """Object representation of an incoming DNS packet"""
463 463
464 464 def __init__(self, data):
465 465 """Constructor from string holding bytes of packet"""
466 466 self.offset = 0
467 467 self.data = data
468 468 self.questions = []
469 469 self.answers = []
470 470 self.numQuestions = 0
471 471 self.numAnswers = 0
472 472 self.numAuthorities = 0
473 473 self.numAdditionals = 0
474 474
475 475 self.readHeader()
476 476 self.readQuestions()
477 477 self.readOthers()
478 478
479 479 def readHeader(self):
480 480 """Reads header portion of packet"""
481 481 format = '!HHHHHH'
482 482 length = struct.calcsize(format)
483 483 info = struct.unpack(format, self.data[self.offset:self.offset+length])
484 484 self.offset += length
485 485
486 486 self.id = info[0]
487 487 self.flags = info[1]
488 488 self.numQuestions = info[2]
489 489 self.numAnswers = info[3]
490 490 self.numAuthorities = info[4]
491 491 self.numAdditionals = info[5]
492 492
493 493 def readQuestions(self):
494 494 """Reads questions section of packet"""
495 495 format = '!HH'
496 496 length = struct.calcsize(format)
497 497 for i in range(0, self.numQuestions):
498 498 name = self.readName()
499 499 info = struct.unpack(format, self.data[self.offset:self.offset+length])
500 500 self.offset += length
501 501
502 502 try:
503 503 question = DNSQuestion(name, info[0], info[1])
504 504 self.questions.append(question)
505 505 except NonLocalNameException:
506 506 pass
507 507
508 508 def readInt(self):
509 509 """Reads an integer from the packet"""
510 510 format = '!I'
511 511 length = struct.calcsize(format)
512 512 info = struct.unpack(format, self.data[self.offset:self.offset+length])
513 513 self.offset += length
514 514 return info[0]
515 515
516 516 def readCharacterString(self):
517 517 """Reads a character string from the packet"""
518 518 length = ord(self.data[self.offset])
519 519 self.offset += 1
520 520 return self.readString(length)
521 521
522 522 def readString(self, len):
523 523 """Reads a string of a given length from the packet"""
524 524 format = '!' + str(len) + 's'
525 525 length = struct.calcsize(format)
526 526 info = struct.unpack(format, self.data[self.offset:self.offset+length])
527 527 self.offset += length
528 528 return info[0]
529 529
530 530 def readUnsignedShort(self):
531 531 """Reads an unsigned short from the packet"""
532 532 format = '!H'
533 533 length = struct.calcsize(format)
534 534 info = struct.unpack(format, self.data[self.offset:self.offset+length])
535 535 self.offset += length
536 536 return info[0]
537 537
538 538 def readOthers(self):
539 539 """Reads the answers, authorities and additionals section of the packet"""
540 540 format = '!HHiH'
541 541 length = struct.calcsize(format)
542 542 n = self.numAnswers + self.numAuthorities + self.numAdditionals
543 543 for i in range(0, n):
544 544 domain = self.readName()
545 545 info = struct.unpack(format, self.data[self.offset:self.offset+length])
546 546 self.offset += length
547 547
548 548 rec = None
549 549 if info[0] == _TYPE_A:
550 550 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(4))
551 551 elif info[0] == _TYPE_CNAME or info[0] == _TYPE_PTR:
552 552 rec = DNSPointer(domain, info[0], info[1], info[2], self.readName())
553 553 elif info[0] == _TYPE_TXT:
554 554 rec = DNSText(domain, info[0], info[1], info[2], self.readString(info[3]))
555 555 elif info[0] == _TYPE_SRV:
556 556 rec = DNSService(domain, info[0], info[1], info[2], self.readUnsignedShort(), self.readUnsignedShort(), self.readUnsignedShort(), self.readName())
557 557 elif info[0] == _TYPE_HINFO:
558 558 rec = DNSHinfo(domain, info[0], info[1], info[2], self.readCharacterString(), self.readCharacterString())
559 559 elif info[0] == _TYPE_AAAA:
560 560 rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(16))
561 561 else:
562 562 # Try to ignore types we don't know about
563 563 # this may mean the rest of the name is
564 564 # unable to be parsed, and may show errors
565 565 # so this is left for debugging. New types
566 566 # encountered need to be parsed properly.
567 567 #
568 568 #print "UNKNOWN TYPE = " + str(info[0])
569 569 #raise BadTypeInNameException
570 570 self.offset += info[3]
571 571
572 572 if rec is not None:
573 573 self.answers.append(rec)
574 574
575 575 def isQuery(self):
576 576 """Returns true if this is a query"""
577 577 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
578 578
579 579 def isResponse(self):
580 580 """Returns true if this is a response"""
581 581 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
582 582
583 583 def readUTF(self, offset, len):
584 584 """Reads a UTF-8 string of a given length from the packet"""
585 585 return self.data[offset:offset+len].decode('utf-8')
586 586
587 587 def readName(self):
588 588 """Reads a domain name from the packet"""
589 589 result = ''
590 590 off = self.offset
591 591 next = -1
592 592 first = off
593 593
594 594 while True:
595 595 len = ord(self.data[off])
596 596 off += 1
597 597 if len == 0:
598 598 break
599 599 t = len & 0xC0
600 600 if t == 0x00:
601 601 result = ''.join((result, self.readUTF(off, len) + '.'))
602 602 off += len
603 603 elif t == 0xC0:
604 604 if next < 0:
605 605 next = off + 1
606 606 off = ((len & 0x3F) << 8) | ord(self.data[off])
607 607 if off >= first:
608 608 raise BadDomainNameCircular(off)
609 609 first = off
610 610 else:
611 611 raise BadDomainName(off)
612 612
613 613 if next >= 0:
614 614 self.offset = next
615 615 else:
616 616 self.offset = off
617 617
618 618 return result
619 619
620 620
621 621 class DNSOutgoing(object):
622 622 """Object representation of an outgoing packet"""
623 623
624 624 def __init__(self, flags, multicast = 1):
625 625 self.finished = 0
626 626 self.id = 0
627 627 self.multicast = multicast
628 628 self.flags = flags
629 629 self.names = {}
630 630 self.data = []
631 631 self.size = 12
632 632
633 633 self.questions = []
634 634 self.answers = []
635 635 self.authorities = []
636 636 self.additionals = []
637 637
638 638 def addQuestion(self, record):
639 639 """Adds a question"""
640 640 self.questions.append(record)
641 641
642 642 def addAnswer(self, inp, record):
643 643 """Adds an answer"""
644 644 if not record.suppressedBy(inp):
645 645 self.addAnswerAtTime(record, 0)
646 646
647 647 def addAnswerAtTime(self, record, now):
648 648 """Adds an answer if if does not expire by a certain time"""
649 649 if record is not None:
650 650 if now == 0 or not record.isExpired(now):
651 651 self.answers.append((record, now))
652 652
653 653 def addAuthoritativeAnswer(self, record):
654 654 """Adds an authoritative answer"""
655 655 self.authorities.append(record)
656 656
657 657 def addAdditionalAnswer(self, record):
658 658 """Adds an additional answer"""
659 659 self.additionals.append(record)
660 660
661 661 def writeByte(self, value):
662 662 """Writes a single byte to the packet"""
663 663 format = '!c'
664 664 self.data.append(struct.pack(format, chr(value)))
665 665 self.size += 1
666 666
667 667 def insertShort(self, index, value):
668 668 """Inserts an unsigned short in a certain position in the packet"""
669 669 format = '!H'
670 670 self.data.insert(index, struct.pack(format, value))
671 671 self.size += 2
672 672
673 673 def writeShort(self, value):
674 674 """Writes an unsigned short to the packet"""
675 675 format = '!H'
676 676 self.data.append(struct.pack(format, value))
677 677 self.size += 2
678 678
679 679 def writeInt(self, value):
680 680 """Writes an unsigned integer to the packet"""
681 681 format = '!I'
682 682 self.data.append(struct.pack(format, int(value)))
683 683 self.size += 4
684 684
685 685 def writeString(self, value, length):
686 686 """Writes a string to the packet"""
687 687 format = '!' + str(length) + 's'
688 688 self.data.append(struct.pack(format, value))
689 689 self.size += length
690 690
691 691 def writeUTF(self, s):
692 692 """Writes a UTF-8 string of a given length to the packet"""
693 693 utfstr = s.encode('utf-8')
694 694 length = len(utfstr)
695 695 if length > 64:
696 696 raise NamePartTooLongException
697 697 self.writeByte(length)
698 698 self.writeString(utfstr, length)
699 699
700 700 def writeName(self, name):
701 701 """Writes a domain name to the packet"""
702 702
703 703 try:
704 704 # Find existing instance of this name in packet
705 705 #
706 706 index = self.names[name]
707 707 except KeyError:
708 708 # No record of this name already, so write it
709 709 # out as normal, recording the location of the name
710 710 # for future pointers to it.
711 711 #
712 712 self.names[name] = self.size
713 713 parts = name.split('.')
714 714 if parts[-1] == '':
715 715 parts = parts[:-1]
716 716 for part in parts:
717 717 self.writeUTF(part)
718 718 self.writeByte(0)
719 719 return
720 720
721 721 # An index was found, so write a pointer to it
722 722 #
723 723 self.writeByte((index >> 8) | 0xC0)
724 724 self.writeByte(index)
725 725
726 726 def writeQuestion(self, question):
727 727 """Writes a question to the packet"""
728 728 self.writeName(question.name)
729 729 self.writeShort(question.type)
730 730 self.writeShort(question.clazz)
731 731
732 732 def writeRecord(self, record, now):
733 733 """Writes a record (answer, authoritative answer, additional) to
734 734 the packet"""
735 735 self.writeName(record.name)
736 736 self.writeShort(record.type)
737 737 if record.unique and self.multicast:
738 738 self.writeShort(record.clazz | _CLASS_UNIQUE)
739 739 else:
740 740 self.writeShort(record.clazz)
741 741 if now == 0:
742 742 self.writeInt(record.ttl)
743 743 else:
744 744 self.writeInt(record.getRemainingTTL(now))
745 745 index = len(self.data)
746 746 # Adjust size for the short we will write before this record
747 747 #
748 748 self.size += 2
749 749 record.write(self)
750 750 self.size -= 2
751 751
752 752 length = len(''.join(self.data[index:]))
753 753 self.insertShort(index, length) # Here is the short we adjusted for
754 754
755 755 def packet(self):
756 756 """Returns a string containing the packet's bytes
757 757
758 758 No further parts should be added to the packet once this
759 759 is done."""
760 760 if not self.finished:
761 761 self.finished = 1
762 762 for question in self.questions:
763 763 self.writeQuestion(question)
764 764 for answer, time in self.answers:
765 765 self.writeRecord(answer, time)
766 766 for authority in self.authorities:
767 767 self.writeRecord(authority, 0)
768 768 for additional in self.additionals:
769 769 self.writeRecord(additional, 0)
770 770
771 771 self.insertShort(0, len(self.additionals))
772 772 self.insertShort(0, len(self.authorities))
773 773 self.insertShort(0, len(self.answers))
774 774 self.insertShort(0, len(self.questions))
775 775 self.insertShort(0, self.flags)
776 776 if self.multicast:
777 777 self.insertShort(0, 0)
778 778 else:
779 779 self.insertShort(0, self.id)
780 780 return ''.join(self.data)
781 781
782 782
783 783 class DNSCache(object):
784 784 """A cache of DNS entries"""
785 785
786 786 def __init__(self):
787 787 self.cache = {}
788 788
789 789 def add(self, entry):
790 790 """Adds an entry"""
791 791 try:
792 792 list = self.cache[entry.key]
793 793 except KeyError:
794 794 list = self.cache[entry.key] = []
795 795 list.append(entry)
796 796
797 797 def remove(self, entry):
798 798 """Removes an entry"""
799 799 try:
800 800 list = self.cache[entry.key]
801 801 list.remove(entry)
802 802 except KeyError:
803 803 pass
804 804
805 805 def get(self, entry):
806 806 """Gets an entry by key. Will return None if there is no
807 807 matching entry."""
808 808 try:
809 809 list = self.cache[entry.key]
810 810 return list[list.index(entry)]
811 811 except (KeyError, ValueError):
812 812 return None
813 813
814 814 def getByDetails(self, name, type, clazz):
815 815 """Gets an entry by details. Will return None if there is
816 816 no matching entry."""
817 817 entry = DNSEntry(name, type, clazz)
818 818 return self.get(entry)
819 819
820 820 def entriesWithName(self, name):
821 821 """Returns a list of entries whose key matches the name."""
822 822 try:
823 823 return self.cache[name]
824 824 except KeyError:
825 825 return []
826 826
827 827 def entries(self):
828 828 """Returns a list of all entries"""
829 829 def add(x, y): return x+y
830 830 try:
831 831 return reduce(add, self.cache.values())
832 832 except Exception:
833 833 return []
834 834
835 835
836 836 class Engine(threading.Thread):
837 837 """An engine wraps read access to sockets, allowing objects that
838 838 need to receive data from sockets to be called back when the
839 839 sockets are ready.
840 840
841 841 A reader needs a handle_read() method, which is called when the socket
842 842 it is interested in is ready for reading.
843 843
844 844 Writers are not implemented here, because we only send short
845 845 packets.
846 846 """
847 847
848 848 def __init__(self, zeroconf):
849 849 threading.Thread.__init__(self)
850 850 self.zeroconf = zeroconf
851 851 self.readers = {} # maps socket to reader
852 852 self.timeout = 5
853 853 self.condition = threading.Condition()
854 854 self.start()
855 855
856 856 def run(self):
857 857 while not globals()['_GLOBAL_DONE']:
858 858 rs = self.getReaders()
859 859 if len(rs) == 0:
860 860 # No sockets to manage, but we wait for the timeout
861 861 # or addition of a socket
862 862 #
863 863 self.condition.acquire()
864 864 self.condition.wait(self.timeout)
865 865 self.condition.release()
866 866 else:
867 867 try:
868 868 rr, wr, er = select.select(rs, [], [], self.timeout)
869 869 for socket in rr:
870 870 try:
871 871 self.readers[socket].handle_read()
872 872 except Exception:
873 873 if not globals()['_GLOBAL_DONE']:
874 874 traceback.print_exc()
875 875 except Exception:
876 876 pass
877 877
878 878 def getReaders(self):
879 879 self.condition.acquire()
880 880 result = self.readers.keys()
881 881 self.condition.release()
882 882 return result
883 883
884 884 def addReader(self, reader, socket):
885 885 self.condition.acquire()
886 886 self.readers[socket] = reader
887 887 self.condition.notify()
888 888 self.condition.release()
889 889
890 890 def delReader(self, socket):
891 891 self.condition.acquire()
892 892 del(self.readers[socket])
893 893 self.condition.notify()
894 894 self.condition.release()
895 895
896 896 def notify(self):
897 897 self.condition.acquire()
898 898 self.condition.notify()
899 899 self.condition.release()
900 900
901 901 class Listener(object):
902 902 """A Listener is used by this module to listen on the multicast
903 903 group to which DNS messages are sent, allowing the implementation
904 904 to cache information as it arrives.
905 905
906 906 It requires registration with an Engine object in order to have
907 907 the read() method called when a socket is available for reading."""
908 908
909 909 def __init__(self, zeroconf):
910 910 self.zeroconf = zeroconf
911 911 self.zeroconf.engine.addReader(self, self.zeroconf.socket)
912 912
913 913 def handle_read(self):
914 914 data, (addr, port) = self.zeroconf.socket.recvfrom(_MAX_MSG_ABSOLUTE)
915 915 self.data = data
916 916 msg = DNSIncoming(data)
917 917 if msg.isQuery():
918 918 # Always multicast responses
919 919 #
920 920 if port == _MDNS_PORT:
921 921 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
922 922 # If it's not a multicast query, reply via unicast
923 923 # and multicast
924 924 #
925 925 elif port == _DNS_PORT:
926 926 self.zeroconf.handleQuery(msg, addr, port)
927 927 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
928 928 else:
929 929 self.zeroconf.handleResponse(msg)
930 930
931 931
932 932 class Reaper(threading.Thread):
933 933 """A Reaper is used by this module to remove cache entries that
934 934 have expired."""
935 935
936 936 def __init__(self, zeroconf):
937 937 threading.Thread.__init__(self)
938 938 self.zeroconf = zeroconf
939 939 self.start()
940 940
941 941 def run(self):
942 942 while True:
943 943 self.zeroconf.wait(10 * 1000)
944 944 if globals()['_GLOBAL_DONE']:
945 945 return
946 946 now = currentTimeMillis()
947 947 for record in self.zeroconf.cache.entries():
948 948 if record.isExpired(now):
949 949 self.zeroconf.updateRecord(now, record)
950 950 self.zeroconf.cache.remove(record)
951 951
952 952
953 953 class ServiceBrowser(threading.Thread):
954 954 """Used to browse for a service of a specific type.
955 955
956 956 The listener object will have its addService() and
957 957 removeService() methods called when this browser
958 958 discovers changes in the services availability."""
959 959
960 960 def __init__(self, zeroconf, type, listener):
961 961 """Creates a browser for a specific type"""
962 962 threading.Thread.__init__(self)
963 963 self.zeroconf = zeroconf
964 964 self.type = type
965 965 self.listener = listener
966 966 self.services = {}
967 967 self.nextTime = currentTimeMillis()
968 968 self.delay = _BROWSER_TIME
969 969 self.list = []
970 970
971 971 self.done = 0
972 972
973 973 self.zeroconf.addListener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
974 974 self.start()
975 975
976 976 def updateRecord(self, zeroconf, now, record):
977 977 """Callback invoked by Zeroconf when new information arrives.
978 978
979 979 Updates information required by browser in the Zeroconf cache."""
980 980 if record.type == _TYPE_PTR and record.name == self.type:
981 981 expired = record.isExpired(now)
982 982 try:
983 983 oldrecord = self.services[record.alias.lower()]
984 984 if not expired:
985 985 oldrecord.resetTTL(record)
986 986 else:
987 987 del(self.services[record.alias.lower()])
988 988 callback = lambda x: self.listener.removeService(x, self.type, record.alias)
989 989 self.list.append(callback)
990 990 return
991 991 except Exception:
992 992 if not expired:
993 993 self.services[record.alias.lower()] = record
994 994 callback = lambda x: self.listener.addService(x, self.type, record.alias)
995 995 self.list.append(callback)
996 996
997 997 expires = record.getExpirationTime(75)
998 998 if expires < self.nextTime:
999 999 self.nextTime = expires
1000 1000
1001 1001 def cancel(self):
1002 1002 self.done = 1
1003 1003 self.zeroconf.notifyAll()
1004 1004
1005 1005 def run(self):
1006 1006 while True:
1007 1007 event = None
1008 1008 now = currentTimeMillis()
1009 1009 if len(self.list) == 0 and self.nextTime > now:
1010 1010 self.zeroconf.wait(self.nextTime - now)
1011 1011 if globals()['_GLOBAL_DONE'] or self.done:
1012 1012 return
1013 1013 now = currentTimeMillis()
1014 1014
1015 1015 if self.nextTime <= now:
1016 1016 out = DNSOutgoing(_FLAGS_QR_QUERY)
1017 1017 out.addQuestion(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1018 1018 for record in self.services.values():
1019 1019 if not record.isExpired(now):
1020 1020 out.addAnswerAtTime(record, now)
1021 1021 self.zeroconf.send(out)
1022 1022 self.nextTime = now + self.delay
1023 1023 self.delay = min(20 * 1000, self.delay * 2)
1024 1024
1025 1025 if len(self.list) > 0:
1026 1026 event = self.list.pop(0)
1027 1027
1028 1028 if event is not None:
1029 1029 event(self.zeroconf)
1030 1030
1031 1031
1032 1032 class ServiceInfo(object):
1033 1033 """Service information"""
1034 1034
1035 1035 def __init__(self, type, name, address=None, port=None, weight=0, priority=0, properties=None, server=None):
1036 1036 """Create a service description.
1037 1037
1038 1038 type: fully qualified service type name
1039 1039 name: fully qualified service name
1040 1040 address: IP address as unsigned short, network byte order
1041 1041 port: port that the service runs on
1042 1042 weight: weight of the service
1043 1043 priority: priority of the service
1044 1044 properties: dictionary of properties (or a string holding the bytes for the text field)
1045 1045 server: fully qualified name for service host (defaults to name)"""
1046 1046
1047 1047 if not name.endswith(type):
1048 1048 raise BadTypeInNameException
1049 1049 self.type = type
1050 1050 self.name = name
1051 1051 self.address = address
1052 1052 self.port = port
1053 1053 self.weight = weight
1054 1054 self.priority = priority
1055 1055 if server:
1056 1056 self.server = server
1057 1057 else:
1058 1058 self.server = name
1059 1059 self.setProperties(properties)
1060 1060
1061 1061 def setProperties(self, properties):
1062 1062 """Sets properties and text of this info from a dictionary"""
1063 1063 if isinstance(properties, dict):
1064 1064 self.properties = properties
1065 1065 list = []
1066 1066 result = ''
1067 1067 for key in properties:
1068 1068 value = properties[key]
1069 1069 if value is None:
1070 1070 suffix = ''
1071 1071 elif isinstance(value, str):
1072 1072 suffix = value
1073 1073 elif isinstance(value, int):
1074 1074 if value:
1075 1075 suffix = 'true'
1076 1076 else:
1077 1077 suffix = 'false'
1078 1078 else:
1079 1079 suffix = ''
1080 1080 list.append('='.join((key, suffix)))
1081 1081 for item in list:
1082 1082 result = ''.join((result, struct.pack('!c', chr(len(item))), item))
1083 1083 self.text = result
1084 1084 else:
1085 1085 self.text = properties
1086 1086
1087 1087 def setText(self, text):
1088 1088 """Sets properties and text given a text field"""
1089 1089 self.text = text
1090 1090 try:
1091 1091 result = {}
1092 1092 end = len(text)
1093 1093 index = 0
1094 1094 strs = []
1095 1095 while index < end:
1096 1096 length = ord(text[index])
1097 1097 index += 1
1098 1098 strs.append(text[index:index+length])
1099 1099 index += length
1100 1100
1101 1101 for s in strs:
1102 1102 eindex = s.find('=')
1103 1103 if eindex == -1:
1104 1104 # No equals sign at all
1105 1105 key = s
1106 1106 value = 0
1107 1107 else:
1108 1108 key = s[:eindex]
1109 1109 value = s[eindex+1:]
1110 1110 if value == 'true':
1111 1111 value = 1
1112 1112 elif value == 'false' or not value:
1113 1113 value = 0
1114 1114
1115 1115 # Only update non-existent properties
1116 1116 if key and result.get(key) == None:
1117 1117 result[key] = value
1118 1118
1119 1119 self.properties = result
1120 1120 except Exception:
1121 1121 traceback.print_exc()
1122 1122 self.properties = None
1123 1123
1124 1124 def getType(self):
1125 1125 """Type accessor"""
1126 1126 return self.type
1127 1127
1128 1128 def getName(self):
1129 1129 """Name accessor"""
1130 1130 if self.type is not None and self.name.endswith("." + self.type):
1131 1131 return self.name[:len(self.name) - len(self.type) - 1]
1132 1132 return self.name
1133 1133
1134 1134 def getAddress(self):
1135 1135 """Address accessor"""
1136 1136 return self.address
1137 1137
1138 1138 def getPort(self):
1139 1139 """Port accessor"""
1140 1140 return self.port
1141 1141
1142 1142 def getPriority(self):
1143 1143 """Priority accessor"""
1144 1144 return self.priority
1145 1145
1146 1146 def getWeight(self):
1147 1147 """Weight accessor"""
1148 1148 return self.weight
1149 1149
1150 1150 def getProperties(self):
1151 1151 """Properties accessor"""
1152 1152 return self.properties
1153 1153
1154 1154 def getText(self):
1155 1155 """Text accessor"""
1156 1156 return self.text
1157 1157
1158 1158 def getServer(self):
1159 1159 """Server accessor"""
1160 1160 return self.server
1161 1161
1162 1162 def updateRecord(self, zeroconf, now, record):
1163 1163 """Updates service information from a DNS record"""
1164 1164 if record is not None and not record.isExpired(now):
1165 1165 if record.type == _TYPE_A:
1166 1166 #if record.name == self.name:
1167 1167 if record.name == self.server:
1168 1168 self.address = record.address
1169 1169 elif record.type == _TYPE_SRV:
1170 1170 if record.name == self.name:
1171 1171 self.server = record.server
1172 1172 self.port = record.port
1173 1173 self.weight = record.weight
1174 1174 self.priority = record.priority
1175 1175 #self.address = None
1176 1176 self.updateRecord(zeroconf, now, zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN))
1177 1177 elif record.type == _TYPE_TXT:
1178 1178 if record.name == self.name:
1179 1179 self.setText(record.text)
1180 1180
1181 1181 def request(self, zeroconf, timeout):
1182 1182 """Returns true if the service could be discovered on the
1183 1183 network, and updates this object with details discovered.
1184 1184 """
1185 1185 now = currentTimeMillis()
1186 1186 delay = _LISTENER_TIME
1187 1187 next = now + delay
1188 1188 last = now + timeout
1189 1189 result = 0
1190 1190 try:
1191 1191 zeroconf.addListener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
1192 1192 while self.server is None or self.address is None or self.text is None:
1193 1193 if last <= now:
1194 1194 return 0
1195 1195 if next <= now:
1196 1196 out = DNSOutgoing(_FLAGS_QR_QUERY)
1197 1197 out.addQuestion(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
1198 1198 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_SRV, _CLASS_IN), now)
1199 1199 out.addQuestion(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
1200 1200 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_TXT, _CLASS_IN), now)
1201 1201 if self.server is not None:
1202 1202 out.addQuestion(DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
1203 1203 out.addAnswerAtTime(zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN), now)
1204 1204 zeroconf.send(out)
1205 1205 next = now + delay
1206 1206 delay = delay * 2
1207 1207
1208 1208 zeroconf.wait(min(next, last) - now)
1209 1209 now = currentTimeMillis()
1210 1210 result = 1
1211 1211 finally:
1212 1212 zeroconf.removeListener(self)
1213 1213
1214 1214 return result
1215 1215
1216 1216 def __eq__(self, other):
1217 1217 """Tests equality of service name"""
1218 1218 if isinstance(other, ServiceInfo):
1219 1219 return other.name == self.name
1220 1220 return 0
1221 1221
1222 1222 def __ne__(self, other):
1223 1223 """Non-equality test"""
1224 1224 return not self.__eq__(other)
1225 1225
1226 1226 def __repr__(self):
1227 1227 """String representation"""
1228 1228 result = "service[%s,%s:%s," % (self.name, socket.inet_ntoa(self.getAddress()), self.port)
1229 1229 if self.text is None:
1230 1230 result += "None"
1231 1231 else:
1232 1232 if len(self.text) < 20:
1233 1233 result += self.text
1234 1234 else:
1235 1235 result += self.text[:17] + "..."
1236 1236 result += "]"
1237 1237 return result
1238 1238
1239 1239
1240 1240 class Zeroconf(object):
1241 1241 """Implementation of Zeroconf Multicast DNS Service Discovery
1242 1242
1243 1243 Supports registration, unregistration, queries and browsing.
1244 1244 """
1245 1245 def __init__(self, bindaddress=None):
1246 1246 """Creates an instance of the Zeroconf class, establishing
1247 1247 multicast communications, listening and reaping threads."""
1248 1248 globals()['_GLOBAL_DONE'] = 0
1249 1249 if bindaddress is None:
1250 1250 self.intf = socket.gethostbyname(socket.gethostname())
1251 1251 else:
1252 1252 self.intf = bindaddress
1253 1253 self.group = ('', _MDNS_PORT)
1254 1254 self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1255 1255 try:
1256 1256 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1257 1257 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
1258 1258 except Exception:
1259 1259 # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
1260 1260 # multicast UDP sockets (p 731, "TCP/IP Illustrated,
1261 1261 # Volume 2"), but some BSD-derived systems require
1262 1262 # SO_REUSEPORT to be specified explicitly. Also, not all
1263 1263 # versions of Python have SO_REUSEPORT available. So
1264 1264 # if you're on a BSD-based system, and haven't upgraded
1265 1265 # to Python 2.3 yet, you may find this library doesn't
1266 1266 # work as expected.
1267 1267 #
1268 1268 pass
1269 1269 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 255)
1270 1270 self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
1271 1271 try:
1272 1272 self.socket.bind(self.group)
1273 1273 except Exception:
1274 1274 # Some versions of linux raise an exception even though
1275 1275 # the SO_REUSE* options have been set, so ignore it
1276 1276 #
1277 1277 pass
1278 #self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(self.intf) + socket.inet_aton('0.0.0.0'))
1279 1278 self.socket.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1280 1279
1281 1280 self.listeners = []
1282 1281 self.browsers = []
1283 1282 self.services = {}
1284 1283 self.servicetypes = {}
1285 1284
1286 1285 self.cache = DNSCache()
1287 1286
1288 1287 self.condition = threading.Condition()
1289 1288
1290 1289 self.engine = Engine(self)
1291 1290 self.listener = Listener(self)
1292 1291 self.reaper = Reaper(self)
1293 1292
1294 1293 def isLoopback(self):
1295 1294 return self.intf.startswith("127.0.0.1")
1296 1295
1297 1296 def isLinklocal(self):
1298 1297 return self.intf.startswith("169.254.")
1299 1298
1300 1299 def wait(self, timeout):
1301 1300 """Calling thread waits for a given number of milliseconds or
1302 1301 until notified."""
1303 1302 self.condition.acquire()
1304 1303 self.condition.wait(timeout/1000)
1305 1304 self.condition.release()
1306 1305
1307 1306 def notifyAll(self):
1308 1307 """Notifies all waiting threads"""
1309 1308 self.condition.acquire()
1310 1309 self.condition.notifyAll()
1311 1310 self.condition.release()
1312 1311
1313 1312 def getServiceInfo(self, type, name, timeout=3000):
1314 1313 """Returns network's service information for a particular
1315 1314 name and type, or None if no service matches by the timeout,
1316 1315 which defaults to 3 seconds."""
1317 1316 info = ServiceInfo(type, name)
1318 1317 if info.request(self, timeout):
1319 1318 return info
1320 1319 return None
1321 1320
1322 1321 def addServiceListener(self, type, listener):
1323 1322 """Adds a listener for a particular service type. This object
1324 1323 will then have its updateRecord method called when information
1325 1324 arrives for that type."""
1326 1325 self.removeServiceListener(listener)
1327 1326 self.browsers.append(ServiceBrowser(self, type, listener))
1328 1327
1329 1328 def removeServiceListener(self, listener):
1330 1329 """Removes a listener from the set that is currently listening."""
1331 1330 for browser in self.browsers:
1332 1331 if browser.listener == listener:
1333 1332 browser.cancel()
1334 1333 del(browser)
1335 1334
1336 1335 def registerService(self, info, ttl=_DNS_TTL):
1337 1336 """Registers service information to the network with a default TTL
1338 1337 of 60 seconds. Zeroconf will then respond to requests for
1339 1338 information for that service. The name of the service may be
1340 1339 changed if needed to make it unique on the network."""
1341 1340 self.checkService(info)
1342 1341 self.services[info.name.lower()] = info
1343 1342 if self.servicetypes.has_key(info.type):
1344 1343 self.servicetypes[info.type]+=1
1345 1344 else:
1346 1345 self.servicetypes[info.type]=1
1347 1346 now = currentTimeMillis()
1348 1347 nextTime = now
1349 1348 i = 0
1350 1349 while i < 3:
1351 1350 if now < nextTime:
1352 1351 self.wait(nextTime - now)
1353 1352 now = currentTimeMillis()
1354 1353 continue
1355 1354 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1356 1355 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
1357 1356 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, ttl, info.priority, info.weight, info.port, info.server), 0)
1358 1357 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
1359 1358 if info.address:
1360 1359 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, ttl, info.address), 0)
1361 1360 self.send(out)
1362 1361 i += 1
1363 1362 nextTime += _REGISTER_TIME
1364 1363
1365 1364 def unregisterService(self, info):
1366 1365 """Unregister a service."""
1367 1366 try:
1368 1367 del(self.services[info.name.lower()])
1369 1368 if self.servicetypes[info.type]>1:
1370 1369 self.servicetypes[info.type]-=1
1371 1370 else:
1372 1371 del self.servicetypes[info.type]
1373 1372 except KeyError:
1374 1373 pass
1375 1374 now = currentTimeMillis()
1376 1375 nextTime = now
1377 1376 i = 0
1378 1377 while i < 3:
1379 1378 if now < nextTime:
1380 1379 self.wait(nextTime - now)
1381 1380 now = currentTimeMillis()
1382 1381 continue
1383 1382 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1384 1383 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1385 1384 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.name), 0)
1386 1385 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1387 1386 if info.address:
1388 1387 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1389 1388 self.send(out)
1390 1389 i += 1
1391 1390 nextTime += _UNREGISTER_TIME
1392 1391
1393 1392 def unregisterAllServices(self):
1394 1393 """Unregister all registered services."""
1395 1394 if len(self.services) > 0:
1396 1395 now = currentTimeMillis()
1397 1396 nextTime = now
1398 1397 i = 0
1399 1398 while i < 3:
1400 1399 if now < nextTime:
1401 1400 self.wait(nextTime - now)
1402 1401 now = currentTimeMillis()
1403 1402 continue
1404 1403 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1405 1404 for info in self.services.values():
1406 1405 out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
1407 1406 out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, 0, info.priority, info.weight, info.port, info.server), 0)
1408 1407 out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
1409 1408 if info.address:
1410 1409 out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0)
1411 1410 self.send(out)
1412 1411 i += 1
1413 1412 nextTime += _UNREGISTER_TIME
1414 1413
1415 1414 def checkService(self, info):
1416 1415 """Checks the network for a unique service name, modifying the
1417 1416 ServiceInfo passed in if it is not unique."""
1418 1417 now = currentTimeMillis()
1419 1418 nextTime = now
1420 1419 i = 0
1421 1420 while i < 3:
1422 1421 for record in self.cache.entriesWithName(info.type):
1423 1422 if record.type == _TYPE_PTR and not record.isExpired(now) and record.alias == info.name:
1424 1423 if (info.name.find('.') < 0):
1425 1424 info.name = info.name + ".[" + info.address + ":" + info.port + "]." + info.type
1426 1425 self.checkService(info)
1427 1426 return
1428 1427 raise NonUniqueNameException
1429 1428 if now < nextTime:
1430 1429 self.wait(nextTime - now)
1431 1430 now = currentTimeMillis()
1432 1431 continue
1433 1432 out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
1434 1433 self.debug = out
1435 1434 out.addQuestion(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
1436 1435 out.addAuthoritativeAnswer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, info.name))
1437 1436 self.send(out)
1438 1437 i += 1
1439 1438 nextTime += _CHECK_TIME
1440 1439
1441 1440 def addListener(self, listener, question):
1442 1441 """Adds a listener for a given question. The listener will have
1443 1442 its updateRecord method called when information is available to
1444 1443 answer the question."""
1445 1444 now = currentTimeMillis()
1446 1445 self.listeners.append(listener)
1447 1446 if question is not None:
1448 1447 for record in self.cache.entriesWithName(question.name):
1449 1448 if question.answeredBy(record) and not record.isExpired(now):
1450 1449 listener.updateRecord(self, now, record)
1451 1450 self.notifyAll()
1452 1451
1453 1452 def removeListener(self, listener):
1454 1453 """Removes a listener."""
1455 1454 try:
1456 1455 self.listeners.remove(listener)
1457 1456 self.notifyAll()
1458 1457 except Exception:
1459 1458 pass
1460 1459
1461 1460 def updateRecord(self, now, rec):
1462 1461 """Used to notify listeners of new information that has updated
1463 1462 a record."""
1464 1463 for listener in self.listeners:
1465 1464 listener.updateRecord(self, now, rec)
1466 1465 self.notifyAll()
1467 1466
1468 1467 def handleResponse(self, msg):
1469 1468 """Deal with incoming response packets. All answers
1470 1469 are held in the cache, and listeners are notified."""
1471 1470 now = currentTimeMillis()
1472 1471 for record in msg.answers:
1473 1472 expired = record.isExpired(now)
1474 1473 if record in self.cache.entries():
1475 1474 if expired:
1476 1475 self.cache.remove(record)
1477 1476 else:
1478 1477 entry = self.cache.get(record)
1479 1478 if entry is not None:
1480 1479 entry.resetTTL(record)
1481 1480 record = entry
1482 1481 else:
1483 1482 self.cache.add(record)
1484 1483
1485 1484 self.updateRecord(now, record)
1486 1485
1487 1486 def handleQuery(self, msg, addr, port):
1488 1487 """Deal with incoming query packets. Provides a response if
1489 1488 possible."""
1490 1489 out = None
1491 1490
1492 1491 # Support unicast client responses
1493 1492 #
1494 1493 if port != _MDNS_PORT:
1495 1494 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, 0)
1496 1495 for question in msg.questions:
1497 1496 out.addQuestion(question)
1498 1497
1499 1498 for question in msg.questions:
1500 1499 if question.type == _TYPE_PTR:
1501 1500 if question.name == "_services._dns-sd._udp.local.":
1502 1501 for stype in self.servicetypes.keys():
1503 1502 if out is None:
1504 1503 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1505 1504 out.addAnswer(msg, DNSPointer("_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_TTL, stype))
1506 1505 for service in self.services.values():
1507 1506 if question.name == service.type:
1508 1507 if out is None:
1509 1508 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1510 1509 out.addAnswer(msg, DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, service.name))
1511 1510 else:
1512 1511 try:
1513 1512 if out is None:
1514 1513 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1515 1514
1516 1515 # Answer A record queries for any service addresses we know
1517 1516 if question.type == _TYPE_A or question.type == _TYPE_ANY:
1518 1517 for service in self.services.values():
1519 1518 if service.server == question.name.lower():
1520 1519 out.addAnswer(msg, DNSAddress(question.name, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1521 1520
1522 1521 service = self.services.get(question.name.lower(), None)
1523 1522 if not service: continue
1524 1523
1525 1524 if question.type == _TYPE_SRV or question.type == _TYPE_ANY:
1526 1525 out.addAnswer(msg, DNSService(question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.priority, service.weight, service.port, service.server))
1527 1526 if question.type == _TYPE_TXT or question.type == _TYPE_ANY:
1528 1527 out.addAnswer(msg, DNSText(question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.text))
1529 1528 if question.type == _TYPE_SRV:
1530 1529 out.addAdditionalAnswer(DNSAddress(service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
1531 1530 except Exception:
1532 1531 traceback.print_exc()
1533 1532
1534 1533 if out is not None and out.answers:
1535 1534 out.id = msg.id
1536 1535 self.send(out, addr, port)
1537 1536
1538 1537 def send(self, out, addr = _MDNS_ADDR, port = _MDNS_PORT):
1539 1538 """Sends an outgoing packet."""
1540 1539 # This is a quick test to see if we can parse the packets we generate
1541 1540 #temp = DNSIncoming(out.packet())
1542 1541 try:
1543 1542 self.socket.sendto(out.packet(), 0, (addr, port))
1544 1543 except Exception:
1545 1544 # Ignore this, it may be a temporary loss of network connection
1546 1545 pass
1547 1546
1548 1547 def close(self):
1549 1548 """Ends the background threads, and prevents this instance from
1550 1549 servicing further queries."""
1551 1550 if globals()['_GLOBAL_DONE'] == 0:
1552 1551 globals()['_GLOBAL_DONE'] = 1
1553 1552 self.notifyAll()
1554 1553 self.engine.notify()
1555 1554 self.unregisterAllServices()
1556 1555 self.socket.setsockopt(socket.SOL_IP, socket.IP_DROP_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
1557 1556 self.socket.close()
1558 1557
1559 1558 # Test a few module features, including service registration, service
1560 1559 # query (for Zoe), and service unregistration.
1561 1560
1562 1561 if __name__ == '__main__':
1563 1562 print "Multicast DNS Service Discovery for Python, version", __version__
1564 1563 r = Zeroconf()
1565 1564 print "1. Testing registration of a service..."
1566 1565 desc = {'version':'0.10','a':'test value', 'b':'another value'}
1567 1566 info = ServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local.", socket.inet_aton("127.0.0.1"), 1234, 0, 0, desc)
1568 1567 print " Registering service..."
1569 1568 r.registerService(info)
1570 1569 print " Registration done."
1571 1570 print "2. Testing query of service information..."
1572 1571 print " Getting ZOE service:", str(r.getServiceInfo("_http._tcp.local.", "ZOE._http._tcp.local."))
1573 1572 print " Query done."
1574 1573 print "3. Testing query of own service..."
1575 1574 print " Getting self:", str(r.getServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local."))
1576 1575 print " Query done."
1577 1576 print "4. Testing unregister of service information..."
1578 1577 r.unregisterService(info)
1579 1578 print " Unregister done."
1580 1579 r.close()
1581 1580
1582 1581 # no-check-code
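As a minimal sketch of what Zeroconf.__init__ above does with the socket, the standalone function below reproduces the multicast plumbing for Python 3 (the bundled module targets Python 2). It assumes the standard mDNS group and port, 224.0.0.251 and 5353, which is what _MDNS_ADDR and _MDNS_PORT are expected to hold, and it uses the portable IPPROTO_IP constant rather than the Linux-only SOL_IP alias; it is illustrative, not part of the module.

    import socket

    MDNS_ADDR = '224.0.0.251'   # standard mDNS multicast group
    MDNS_PORT = 5353            # standard mDNS UDP port

    def open_mdns_socket(bindaddress='0.0.0.0'):
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # allow several responders to share the port; SO_REUSEPORT is not
        # available everywhere, hence the guard (mirrors the try/except above)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, 'SO_REUSEPORT'):
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        # mDNS answers are sent with TTL 255 and loopback delivery enabled
        s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
        s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 1)
        s.bind(('', MDNS_PORT))
        # join the multicast group on the requested interface
        mreq = socket.inet_aton(MDNS_ADDR) + socket.inet_aton(bindaddress)
        s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        return s

Leaving the bind address as '0.0.0.0' matches the constructor's behaviour of binding to every interface and joining the group on the default one.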
@@ -1,206 +1,204
1 1 # manifest.py - manifest revision class for mercurial
2 2 #
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from i18n import _
9 9 import mdiff, parsers, error, revlog, util
10 10 import array, struct
11 11
12 12 class manifestdict(dict):
13 13 def __init__(self, mapping=None, flags=None):
14 14 if mapping is None:
15 15 mapping = {}
16 16 if flags is None:
17 17 flags = {}
18 18 dict.__init__(self, mapping)
19 19 self._flags = flags
20 20 def flags(self, f):
21 21 return self._flags.get(f, "")
22 22 def withflags(self):
23 23 return set(self._flags.keys())
24 24 def set(self, f, flags):
25 25 self._flags[f] = flags
26 26 def copy(self):
27 27 return manifestdict(self, dict.copy(self._flags))
28 28
29 29 class manifest(revlog.revlog):
30 30 def __init__(self, opener):
31 31 self._mancache = None
32 32 revlog.revlog.__init__(self, opener, "00manifest.i")
33 33
34 34 def parse(self, lines):
35 35 mfdict = manifestdict()
36 36 parsers.parse_manifest(mfdict, mfdict._flags, lines)
37 37 return mfdict
38 38
39 39 def readdelta(self, node):
40 40 r = self.rev(node)
41 41 return self.parse(mdiff.patchtext(self.revdiff(self.deltaparent(r), r)))
42 42
43 43 def readfast(self, node):
44 44 '''use the faster of readdelta or read'''
45 45 r = self.rev(node)
46 46 deltaparent = self.deltaparent(r)
47 47 if deltaparent != revlog.nullrev and deltaparent in self.parentrevs(r):
48 48 return self.readdelta(node)
49 49 return self.read(node)
50 50
51 51 def read(self, node):
52 52 if node == revlog.nullid:
53 53 return manifestdict() # don't upset local cache
54 54 if self._mancache and self._mancache[0] == node:
55 55 return self._mancache[1]
56 56 text = self.revision(node)
57 57 arraytext = array.array('c', text)
58 58 mapping = self.parse(text)
59 59 self._mancache = (node, mapping, arraytext)
60 60 return mapping
61 61
62 62 def _search(self, m, s, lo=0, hi=None):
63 63 '''return a tuple (start, end) that says where to find s within m.
64 64
65 65 If the string is found, m[start:end] is the line containing
66 66 that string. If start == end, the string was not found and
67 they indicate the proper sorted insertion point. This was
68 taken from bisect_left, and modified to find line start/end as
69 it goes along.
67 they indicate the proper sorted insertion point.
70 68
71 69 m should be a buffer or a string
72 70 s is a string'''
73 71 def advance(i, c):
74 72 while i < lenm and m[i] != c:
75 73 i += 1
76 74 return i
77 75 if not s:
78 76 return (lo, lo)
79 77 lenm = len(m)
80 78 if not hi:
81 79 hi = lenm
82 80 while lo < hi:
83 81 mid = (lo + hi) // 2
84 82 start = mid
85 83 while start > 0 and m[start - 1] != '\n':
86 84 start -= 1
87 85 end = advance(start, '\0')
88 86 if m[start:end] < s:
89 87 # we know that after the null there are 40 bytes of sha1
90 88 # this translates to the bisect lo = mid + 1
91 89 lo = advance(end + 40, '\n') + 1
92 90 else:
93 91 # this translates to the bisect hi = mid
94 92 hi = start
95 93 end = advance(lo, '\0')
96 94 found = m[lo:end]
97 95 if s == found:
98 96 # we know that after the null there are 40 bytes of sha1
99 97 end = advance(end + 40, '\n')
100 98 return (lo, end + 1)
101 99 else:
102 100 return (lo, lo)
103 101
104 102 def find(self, node, f):
105 103 '''look up entry for a single file efficiently.
106 104 return (node, flags) pair if found, (None, None) if not.'''
107 105 if self._mancache and self._mancache[0] == node:
108 106 return self._mancache[1].get(f), self._mancache[1].flags(f)
109 107 text = self.revision(node)
110 108 start, end = self._search(text, f)
111 109 if start == end:
112 110 return None, None
113 111 l = text[start:end]
114 112 f, n = l.split('\0')
115 113 return revlog.bin(n[:40]), n[40:-1]
116 114
117 115 def add(self, map, transaction, link, p1=None, p2=None,
118 116 changed=None):
119 117 # apply the changes collected during the bisect loop to our addlist
120 118 # return a delta suitable for addrevision
121 119 def addlistdelta(addlist, x):
122 120 # start from the bottom up
123 121 # so changes to the offsets don't mess things up.
124 122 for start, end, content in reversed(x):
125 123 if content:
126 124 addlist[start:end] = array.array('c', content)
127 125 else:
128 126 del addlist[start:end]
129 127 return "".join(struct.pack(">lll", start, end, len(content))
130 128 + content for start, end, content in x)
131 129
132 130 def checkforbidden(l):
133 131 for f in l:
134 132 if '\n' in f or '\r' in f:
135 133 raise error.RevlogError(
136 134 _("'\\n' and '\\r' disallowed in filenames: %r") % f)
137 135
138 136 # if we're using the cache, make sure it is valid and
139 137 # parented by the same node we're diffing against
140 138 if not (changed and self._mancache and p1 and self._mancache[0] == p1):
141 139 files = sorted(map)
142 140 checkforbidden(files)
143 141
144 142 # if this is changed to support newlines in filenames,
145 143 # be sure to check the templates/ dir again (especially *-raw.tmpl)
146 144 hex, flags = revlog.hex, map.flags
147 145 text = ''.join("%s\0%s%s\n" % (f, hex(map[f]), flags(f))
148 146 for f in files)
149 147 arraytext = array.array('c', text)
150 148 cachedelta = None
151 149 else:
152 150 added, removed = changed
153 151 addlist = self._mancache[2]
154 152
155 153 checkforbidden(added)
156 154 # combine the changed lists into one list for sorting
157 155 work = [(x, False) for x in added]
158 156 work.extend((x, True) for x in removed)
159 157 # this could use heapq.merge() (from python2.6+) or equivalent
160 158 # since the lists are already sorted
161 159 work.sort()
162 160
163 161 delta = []
164 162 dstart = None
165 163 dend = None
166 164 dline = [""]
167 165 start = 0
168 166 # zero copy representation of addlist as a buffer
169 167 addbuf = util.buffer(addlist)
170 168
171 169 # start with a readonly loop that finds the offset of
172 170 # each line and creates the deltas
173 171 for f, todelete in work:
174 172 # bs will either be the index of the item or the insert point
175 173 start, end = self._search(addbuf, f, start)
176 174 if not todelete:
177 175 l = "%s\0%s%s\n" % (f, revlog.hex(map[f]), map.flags(f))
178 176 else:
179 177 if start == end:
180 178 # item we want to delete was not found, error out
181 179 raise AssertionError(
182 180 _("failed to remove %s from manifest") % f)
183 181 l = ""
184 182 if dstart is not None and dstart <= start and dend >= start:
185 183 if dend < end:
186 184 dend = end
187 185 if l:
188 186 dline.append(l)
189 187 else:
190 188 if dstart is not None:
191 189 delta.append([dstart, dend, "".join(dline)])
192 190 dstart = start
193 191 dend = end
194 192 dline = [l]
195 193
196 194 if dstart is not None:
197 195 delta.append([dstart, dend, "".join(dline)])
198 196 # apply the delta to the addlist, and get a delta for addrevision
199 197 cachedelta = (self.rev(p1), addlistdelta(addlist, delta))
200 198 arraytext = addlist
201 199 text = util.buffer(arraytext)
202 200
203 201 n = self.addrevision(text, transaction, link, p1, p2, cachedelta)
204 202 self._mancache = (n, map, arraytext)
205 203
206 204 return n
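The text format that add() writes and _search() bisects is compact enough to illustrate with a few throwaway lines. The sketch below is illustrative, not Mercurial code: each entry is the filename, a NUL byte, 40 hex digits of the file node, optional flag characters, and a newline, with entries sorted by filename so a bisect over line starts works.

    # toy manifest text: "<file>\0<40 hex chars><flags>\n", sorted by file name
    entries = [
        ("README", "1" * 40, ""),
        ("bin/run", "2" * 40, "x"),    # 'x' flag marks an executable
        ("src/a.py", "3" * 40, ""),
    ]
    text = "".join("%s\0%s%s\n" % (f, node, flags) for f, node, flags in entries)

    def toy_find(text, fname):
        # linear scan for clarity; manifest._search() locates the same line by
        # bisection, relying on the sort order and the fixed 40-byte node field
        for line in text.split("\n"):
            if not line:
                continue
            f, rest = line.split("\0")
            if f == fname:
                return rest[:40], rest[40:]    # (hex node, flags)
        return None, None

    print(toy_find(text, "bin/run"))    # ('2222...2222', 'x')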
@@ -1,331 +1,327
1 1 # obsolete.py - obsolete markers handling
2 2 #
3 3 # Copyright 2012 Pierre-Yves David <pierre-yves.david@ens-lyon.org>
4 4 # Logilab SA <contact@logilab.fr>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 """Obsolete markers handling
10 10
11 11 An obsolete marker maps an old changeset to a list of new
12 12 changesets. If the list of new changesets is empty, the old changeset
13 13 is said to be "killed". Otherwise, the old changeset is being
14 14 "replaced" by the new changesets.
15 15
16 16 Obsolete markers can be used to record and distribute changeset graph
17 17 transformations performed by history rewriting operations, and help
18 18 building new tools to reconciliate conflicting rewriting actions. To
19 19 facilitate conflicts resolution, markers include various annotations
20 20 besides old and news changeset identifiers, such as creation date or
21 21 author name.
22 22
23 23
24 24 Format
25 25 ------
26 26
27 27 Markers are stored in an append-only file stored in
28 28 '.hg/store/obsstore'.
29 29
30 30 The file starts with a version header:
31 31
32 32 - 1 unsigned byte: version number, starting at zero.
33 33
34 34
35 35 The header is followed by the markers. Each marker is made of:
36 36
37 37 - 1 unsigned byte: number of new changesets "N", can be zero.
38 38
39 39 - 1 unsigned 32-bits integer: metadata size "M" in bytes.
40 40
41 41 - 1 byte: a bit field. It is reserved for flags used in obsolete
42 42 markers common operations, to avoid repeated decoding of metadata
43 43 entries.
44 44
45 45 - 20 bytes: obsoleted changeset identifier.
46 46
47 47 - N*20 bytes: new changeset identifiers.
48 48
49 49 - M bytes: metadata as a sequence of nul-terminated strings. Each
50 50 string contains a key and a value, separated by a color ':', without
51 51 additional encoding. Keys cannot contain '\0' or ':' and values
52 52 cannot contain '\0'.
53 53 """
54 54 import struct
55 55 import util, base85
56 56 from i18n import _
57 57
58 # the obsolete feature is not mature enought to be enabled by default.
59 # you have to rely on third party extension extension to enable this.
60 _enabled = False
61
62 58 _pack = struct.pack
63 59 _unpack = struct.unpack
64 60
65 61 # the obsolete feature is not mature enough to be enabled by default.
66 62 # you have to rely on a third party extension to enable this.
67 63 _enabled = False
68 64
69 65 # data used for parsing and writing
70 66 _fmversion = 0
71 67 _fmfixed = '>BIB20s'
72 68 _fmnode = '20s'
73 69 _fmfsize = struct.calcsize(_fmfixed)
74 70 _fnodesize = struct.calcsize(_fmnode)
75 71
76 72 def _readmarkers(data):
77 73 """Read and enumerate markers from raw data"""
78 74 off = 0
79 75 diskversion = _unpack('>B', data[off:off + 1])[0]
80 76 off += 1
81 77 if diskversion != _fmversion:
82 78 raise util.Abort(_('parsing obsolete marker: unknown version %r')
83 79 % diskversion)
84 80
85 81 # Loop on markers
86 82 l = len(data)
87 83 while off + _fmfsize <= l:
88 84 # read fixed part
89 85 cur = data[off:off + _fmfsize]
90 86 off += _fmfsize
91 87 nbsuc, mdsize, flags, pre = _unpack(_fmfixed, cur)
92 88 # read replacement
93 89 sucs = ()
94 90 if nbsuc:
95 91 s = (_fnodesize * nbsuc)
96 92 cur = data[off:off + s]
97 93 sucs = _unpack(_fmnode * nbsuc, cur)
98 94 off += s
99 95 # read metadata
100 96 # (metadata will be decoded on demand)
101 97 metadata = data[off:off + mdsize]
102 98 if len(metadata) != mdsize:
103 99 raise util.Abort(_('parsing obsolete marker: metadata is too '
104 100 'short, %d bytes expected, got %d')
105 101 % (mdsize, len(metadata)))
106 102 off += mdsize
107 103 yield (pre, sucs, flags, metadata)
108 104
109 105 def encodemeta(meta):
110 106 """Return encoded metadata string to string mapping.
111 107
112 108 Assume no ':' in key and no '\0' in both key and value."""
113 109 for key, value in meta.iteritems():
114 110 if ':' in key or '\0' in key:
115 111 raise ValueError("':' and '\0' are forbidden in metadata keys")
116 112 if '\0' in value:
117 113 raise ValueError("'\0' is forbidden in metadata values")
118 114 return '\0'.join(['%s:%s' % (k, meta[k]) for k in sorted(meta)])
119 115
120 116 def decodemeta(data):
121 117 """Return string to string dictionary from encoded version."""
122 118 d = {}
123 119 for l in data.split('\0'):
124 120 if l:
125 121 key, value = l.split(':')
126 122 d[key] = value
127 123 return d
128 124
129 125 class marker(object):
130 126 """Wrap obsolete marker raw data"""
131 127
132 128 def __init__(self, repo, data):
133 129 # the repo argument will be used to create changectx in later version
134 130 self._repo = repo
135 131 self._data = data
136 132 self._decodedmeta = None
137 133
138 134 def precnode(self):
139 135 """Precursor changeset node identifier"""
140 136 return self._data[0]
141 137
142 138 def succnodes(self):
143 139 """List of successor changesets node identifiers"""
144 140 return self._data[1]
145 141
146 142 def metadata(self):
147 143 """Decoded metadata dictionary"""
148 144 if self._decodedmeta is None:
149 145 self._decodedmeta = decodemeta(self._data[3])
150 146 return self._decodedmeta
151 147
152 148 def date(self):
153 149 """Creation date as (unixtime, offset)"""
154 150 parts = self.metadata()['date'].split(' ')
155 151 return (float(parts[0]), int(parts[1]))
156 152
157 153 class obsstore(object):
158 154 """Store obsolete markers
159 155
160 156 Markers can be accessed with two mappings:
161 157 - precursors: old -> set(new)
162 158 - successors: new -> set(old)
163 159 """
164 160
165 161 def __init__(self, sopener):
166 162 self._all = []
167 163 # new markers to serialize
168 164 self.precursors = {}
169 165 self.successors = {}
170 166 self.sopener = sopener
171 167 data = sopener.tryread('obsstore')
172 168 if data:
173 169 self._load(_readmarkers(data))
174 170
175 171 def __iter__(self):
176 172 return iter(self._all)
177 173
178 174 def __nonzero__(self):
179 175 return bool(self._all)
180 176
181 177 def create(self, transaction, prec, succs=(), flag=0, metadata=None):
182 178 """obsolete: add a new obsolete marker
183 179
184 180 * ensure it is hashable
185 181 * check mandatory metadata
186 182 * encode metadata
187 183 """
188 184 if metadata is None:
189 185 metadata = {}
190 186 if len(prec) != 20:
191 187 raise ValueError(prec)
192 188 for succ in succs:
193 189 if len(succ) != 20:
194 190 raise ValueError(succ)
195 191 marker = (str(prec), tuple(succs), int(flag), encodemeta(metadata))
196 192 self.add(transaction, [marker])
197 193
198 194 def add(self, transaction, markers):
199 195 """Add new markers to the store
200 196
201 197 Take care of filtering duplicates.
202 198 Return the number of new markers."""
203 199 if not _enabled:
204 200 raise util.Abort('obsolete feature is not enabled on this repo')
205 201 new = [m for m in markers if m not in self._all]
206 202 if new:
207 203 f = self.sopener('obsstore', 'ab')
208 204 try:
209 205 # Whether the file's current position is at the begin or at
210 206 # the end after opening a file for appending is implementation
211 207 # defined. So we must seek to the end before calling tell(),
212 208 # or we may get a zero offset for non-zero sized files on
213 209 # some platforms (issue3543).
214 210 f.seek(0, 2) # os.SEEK_END
215 211 offset = f.tell()
216 212 transaction.add('obsstore', offset)
217 213 # offset == 0: new file - add the version header
218 214 for bytes in _encodemarkers(new, offset == 0):
219 215 f.write(bytes)
220 216 finally:
221 217 # XXX: f.close() == filecache invalidation == obsstore rebuilt.
222 218 # call 'filecacheentry.refresh()' here
223 219 f.close()
224 220 self._load(new)
225 221 return len(new)
226 222
227 223 def mergemarkers(self, transaction, data):
228 224 markers = _readmarkers(data)
229 225 self.add(transaction, markers)
230 226
231 227 def _load(self, markers):
232 228 for mark in markers:
233 229 self._all.append(mark)
234 230 pre, sucs = mark[:2]
235 231 self.precursors.setdefault(pre, set()).add(mark)
236 232 for suc in sucs:
237 233 self.successors.setdefault(suc, set()).add(mark)
238 234
239 235 def _encodemarkers(markers, addheader=False):
240 236 # Kept separate from flushmarkers(), it will be reused for
241 237 # markers exchange.
242 238 if addheader:
243 239 yield _pack('>B', _fmversion)
244 240 for marker in markers:
245 241 yield _encodeonemarker(marker)
246 242
247 243
248 244 def _encodeonemarker(marker):
249 245 pre, sucs, flags, metadata = marker
250 246 nbsuc = len(sucs)
251 247 format = _fmfixed + (_fmnode * nbsuc)
252 248 data = [nbsuc, len(metadata), flags, pre]
253 249 data.extend(sucs)
254 250 return _pack(format, *data) + metadata
255 251
256 252 # arbitrary picked to fit into 8K limit from HTTP server
257 253 # you have to take in account:
258 254 # - the version header
259 255 # - the base85 encoding
260 256 _maxpayload = 5300
261 257
262 258 def listmarkers(repo):
263 259 """List markers over pushkey"""
264 260 if not repo.obsstore:
265 261 return {}
266 262 keys = {}
267 263 parts = []
268 264 currentlen = _maxpayload * 2 # ensure we create a new part
269 265 for marker in repo.obsstore:
270 266 nextdata = _encodeonemarker(marker)
271 267 if (len(nextdata) + currentlen > _maxpayload):
272 268 currentpart = []
273 269 currentlen = 0
274 270 parts.append(currentpart)
275 271 currentpart.append(nextdata)
276 272 currentlen += len(nextdata)
277 273 for idx, part in enumerate(reversed(parts)):
278 274 data = ''.join([_pack('>B', _fmversion)] + part)
279 275 keys['dump%i' % idx] = base85.b85encode(data)
280 276 return keys
281 277
282 278 def pushmarker(repo, key, old, new):
283 279 """Push markers over pushkey"""
284 280 if not key.startswith('dump'):
285 281 repo.ui.warn(_('unknown key: %r') % key)
286 282 return 0
287 283 if old:
288 284 repo.ui.warn(_('unexpected old value for %r') % key)
289 285 return 0
290 286 data = base85.b85decode(new)
291 287 lock = repo.lock()
292 288 try:
293 289 tr = repo.transaction('pushkey: obsolete markers')
294 290 try:
295 291 repo.obsstore.mergemarkers(tr, data)
296 292 tr.close()
297 293 return 1
298 294 finally:
299 295 tr.release()
300 296 finally:
301 297 lock.release()
302 298
303 299 def allmarkers(repo):
304 300 """all obsolete markers known in a repository"""
305 301 for markerdata in repo.obsstore:
306 302 yield marker(repo, markerdata)
307 303
308 304 def precursormarkers(ctx):
309 305 """obsolete marker making this changeset obsolete"""
310 306 for data in ctx._repo.obsstore.precursors.get(ctx.node(), ()):
311 307 yield marker(ctx._repo, data)
312 308
313 309 def successormarkers(ctx):
314 310 """obsolete marker marking this changeset as a successor"""
315 311 for data in ctx._repo.obsstore.successors.get(ctx.node(), ()):
316 312 yield marker(ctx._repo, data)
317 313
318 314 def anysuccessors(obsstore, node):
319 315 """Yield every successor of <node>
320 316
321 317 This is a linear yield, unsuitable for detecting split changesets."""
322 318 remaining = set([node])
323 319 seen = set(remaining)
324 320 while remaining:
325 321 current = remaining.pop()
326 322 yield current
327 323 for mark in obsstore.precursors.get(current, ()):
328 324 for suc in mark[1]:
329 325 if suc not in seen:
330 326 seen.add(suc)
331 327 remaining.add(suc)
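The on-disk marker layout described in the module docstring (and consumed by _readmarkers and _encodeonemarker above) can be exercised with a short round-trip. The sketch below is illustrative rather than part of the module: it reuses the documented '>BIB20s' fixed part and '20s' successor format, and it targets Python 3, so hashes and metadata are bytes.

    import struct

    fmfixed, fmnode = '>BIB20s', '20s'

    def pack_marker(prec, succs, flags, meta):
        # metadata: nul-separated "key:value" pairs, as encodemeta() produces
        metadata = '\0'.join('%s:%s' % (k, v) for k, v in sorted(meta.items()))
        metadata = metadata.encode('ascii')
        fmt = fmfixed + fmnode * len(succs)
        return struct.pack(fmt, len(succs), len(metadata), flags,
                           prec, *succs) + metadata

    def unpack_marker(data):
        fsize = struct.calcsize(fmfixed)
        nbsuc, mdsize, flags, prec = struct.unpack(fmfixed, data[:fsize])
        off = fsize + 20 * nbsuc
        succs = struct.unpack(fmnode * nbsuc, data[fsize:off])
        return prec, succs, flags, data[off:off + mdsize]

    old = b'\x11' * 20                       # obsoleted changeset id
    new = b'\x22' * 20                       # its successor
    raw = pack_marker(old, [new], 0, {'user': 'alice', 'date': '0 0'})
    print(unpack_marker(raw) == (old, (new,), 0, b'date:0 0\x00user:alice'))

A real obsstore file is simply the one-byte version header followed by a sequence of such records.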
@@ -1,200 +1,197
1 1 # setdiscovery.py - improved discovery of common nodeset for mercurial
2 2 #
3 3 # Copyright 2010 Benoit Boissinot <bboissin@gmail.com>
4 4 # and Peter Arrenbrecht <peter@arrenbrecht.ch>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from node import nullid
10 10 from i18n import _
11 11 import random, util, dagutil
12 12
13 13 def _updatesample(dag, nodes, sample, always, quicksamplesize=0):
14 14 # if nodes is empty we scan the entire graph
15 15 if nodes:
16 16 heads = dag.headsetofconnecteds(nodes)
17 17 else:
18 18 heads = dag.heads()
19 19 dist = {}
20 20 visit = util.deque(heads)
21 21 seen = set()
22 22 factor = 1
23 23 while visit:
24 24 curr = visit.popleft()
25 25 if curr in seen:
26 26 continue
27 27 d = dist.setdefault(curr, 1)
28 28 if d > factor:
29 29 factor *= 2
30 30 if d == factor:
31 31 if curr not in always: # need this check for the early exit below
32 32 sample.add(curr)
33 33 if quicksamplesize and (len(sample) >= quicksamplesize):
34 34 return
35 35 seen.add(curr)
36 36 for p in dag.parents(curr):
37 37 if not nodes or p in nodes:
38 38 dist.setdefault(p, d + 1)
39 39 visit.append(p)
40 40
41 41 def _setupsample(dag, nodes, size):
42 42 if len(nodes) <= size:
43 43 return set(nodes), None, 0
44 44 always = dag.headsetofconnecteds(nodes)
45 45 desiredlen = size - len(always)
46 46 if desiredlen <= 0:
47 47 # This could be bad if there are very many heads, all unknown to the
48 48 # server. We're counting on long request support here.
49 49 return always, None, desiredlen
50 50 return always, set(), desiredlen
51 51
52 52 def _takequicksample(dag, nodes, size, initial):
53 53 always, sample, desiredlen = _setupsample(dag, nodes, size)
54 54 if sample is None:
55 55 return always
56 56 if initial:
57 57 fromset = None
58 58 else:
59 59 fromset = nodes
60 60 _updatesample(dag, fromset, sample, always, quicksamplesize=desiredlen)
61 61 sample.update(always)
62 62 return sample
63 63
64 64 def _takefullsample(dag, nodes, size):
65 65 always, sample, desiredlen = _setupsample(dag, nodes, size)
66 66 if sample is None:
67 67 return always
68 68 # update from heads
69 69 _updatesample(dag, nodes, sample, always)
70 70 # update from roots
71 71 _updatesample(dag.inverse(), nodes, sample, always)
72 72 assert sample
73 73 if len(sample) > desiredlen:
74 74 sample = set(random.sample(sample, desiredlen))
75 75 elif len(sample) < desiredlen:
76 76 more = desiredlen - len(sample)
77 77 sample.update(random.sample(list(nodes - sample - always), more))
78 78 sample.update(always)
79 79 return sample
80 80
81 81 def findcommonheads(ui, local, remote,
82 82 initialsamplesize=100,
83 83 fullsamplesize=200,
84 84 abortwhenunrelated=True):
85 85 '''Return a tuple (common, anyincoming, remoteheads) used to identify
86 86 missing nodes from or in remote.
87
88 shortcutlocal determines whether we try use direct access to localrepo if
89 remote is actually local.
90 87 '''
91 88 roundtrips = 0
92 89 cl = local.changelog
93 90 dag = dagutil.revlogdag(cl)
94 91
95 92 # early exit if we know all the specified remote heads already
96 93 ui.debug("query 1; heads\n")
97 94 roundtrips += 1
98 95 ownheads = dag.heads()
99 96 sample = ownheads
100 97 if remote.local():
101 98 # stopgap until we have a proper localpeer that supports batch()
102 99 srvheadhashes = remote.heads()
103 100 yesno = remote.known(dag.externalizeall(sample))
104 101 elif remote.capable('batch'):
105 102 batch = remote.batch()
106 103 srvheadhashesref = batch.heads()
107 104 yesnoref = batch.known(dag.externalizeall(sample))
108 105 batch.submit()
109 106 srvheadhashes = srvheadhashesref.value
110 107 yesno = yesnoref.value
111 108 else:
112 109 # compatibility with pre-batch, but post-known remotes during 1.9
113 110 # development
114 111 srvheadhashes = remote.heads()
115 112 sample = []
116 113
117 114 if cl.tip() == nullid:
118 115 if srvheadhashes != [nullid]:
119 116 return [nullid], True, srvheadhashes
120 117 return [nullid], False, []
121 118
122 119 # start actual discovery (we note this before the next "if" for
123 120 # compatibility reasons)
124 121 ui.status(_("searching for changes\n"))
125 122
126 123 srvheads = dag.internalizeall(srvheadhashes, filterunknown=True)
127 124 if len(srvheads) == len(srvheadhashes):
128 125 ui.debug("all remote heads known locally\n")
129 126 return (srvheadhashes, False, srvheadhashes,)
130 127
131 128 if sample and util.all(yesno):
132 129 ui.note(_("all local heads known remotely\n"))
133 130 ownheadhashes = dag.externalizeall(ownheads)
134 131 return (ownheadhashes, True, srvheadhashes,)
135 132
136 133 # full blown discovery
137 134
138 135 # own nodes where I don't know if remote knows them
139 136 undecided = dag.nodeset()
140 137 # own nodes I know we both know
141 138 common = set()
142 139 # own nodes I know remote lacks
143 140 missing = set()
144 141
145 142 # treat remote heads (and maybe own heads) as a first implicit sample
146 143 # response
147 144 common.update(dag.ancestorset(srvheads))
148 145 undecided.difference_update(common)
149 146
150 147 full = False
151 148 while undecided:
152 149
153 150 if sample:
154 151 commoninsample = set(n for i, n in enumerate(sample) if yesno[i])
155 152 common.update(dag.ancestorset(commoninsample, common))
156 153
157 154 missinginsample = [n for i, n in enumerate(sample) if not yesno[i]]
158 155 missing.update(dag.descendantset(missinginsample, missing))
159 156
160 157 undecided.difference_update(missing)
161 158 undecided.difference_update(common)
162 159
163 160 if not undecided:
164 161 break
165 162
166 163 if full:
167 164 ui.note(_("sampling from both directions\n"))
168 165 sample = _takefullsample(dag, undecided, size=fullsamplesize)
169 166 elif common:
170 167 # use cheapish initial sample
171 168 ui.debug("taking initial sample\n")
172 169 sample = _takefullsample(dag, undecided, size=fullsamplesize)
173 170 else:
174 171 # use even cheaper initial sample
175 172 ui.debug("taking quick initial sample\n")
176 173 sample = _takequicksample(dag, undecided, size=initialsamplesize,
177 174 initial=True)
178 175
179 176 roundtrips += 1
180 177 ui.progress(_('searching'), roundtrips, unit=_('queries'))
181 178 ui.debug("query %i; still undecided: %i, sample size is: %i\n"
182 179 % (roundtrips, len(undecided), len(sample)))
183 180 # indices between sample and externalized version must match
184 181 sample = list(sample)
185 182 yesno = remote.known(dag.externalizeall(sample))
186 183 full = True
187 184
188 185 result = dag.headsetofconnecteds(common)
189 186 ui.progress(_('searching'), None)
190 187 ui.debug("%d total queries\n" % roundtrips)
191 188
192 189 if not result and srvheadhashes != [nullid]:
193 190 if abortwhenunrelated:
194 191 raise util.Abort(_("repository is unrelated"))
195 192 else:
196 193 ui.warn(_("warning: repository is unrelated\n"))
197 194 return (set([nullid]), True, srvheadhashes,)
198 195
199 196 anyincoming = (srvheadhashes != [nullid])
200 197 return dag.externalizeall(result), anyincoming, srvheadhashes
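The loop in findcommonheads() is easier to follow once the invariant is explicit: nodes whose membership the remote confirmed pull their whole ancestry into common, nodes it denied push their whole descendancy into missing, and only the rest stays undecided for the next sample. The toy sketch below (Python 3, a made-up five-node DAG, no Mercurial APIs) shows just that bookkeeping.

    import random

    parents = {                       # tiny DAG: node -> parents
        'a': [], 'b': ['a'], 'c': ['b'], 'd': ['b'], 'e': ['d'],
    }
    children = {}
    for n, ps in parents.items():
        for p in ps:
            children.setdefault(p, []).append(n)

    def closure(start, edges):
        # every node reachable from 'start' by repeatedly following 'edges'
        seen, stack = set(), list(start)
        while stack:
            n = stack.pop()
            if n not in seen:
                seen.add(n)
                stack.extend(edges.get(n, []))
        return seen

    remote_knows = {'a', 'b', 'c'}    # what the other repository has
    undecided = set(parents)
    common, missing = set(), set()
    while undecided:
        sample = random.sample(sorted(undecided), min(2, len(undecided)))
        yesno = [n in remote_knows for n in sample]          # one round trip
        common |= closure([n for n, y in zip(sample, yesno) if y], parents)
        missing |= closure([n for n, y in zip(sample, yesno) if not y], children)
        undecided -= common | missing
    print(sorted(common), sorted(missing))    # ['a', 'b', 'c'] ['d', 'e']

The real implementation is smarter about which nodes go into each sample, biasing toward heads and roots via _takequicksample and _takefullsample, but the convergence argument is the same.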