##// END OF EJS Templates
zeroconf: use str instead of bytes when indexing `globals()`...
Matt Harbison -
r52893:c76c1c94 default
parent child Browse files
Show More
@@ -1,1903 +1,1903
1 1 from __future__ import annotations
2 2
3 3 """ Multicast DNS Service Discovery for Python, v0.12
4 4 Copyright (C) 2003, Paul Scott-Murphy
5 5
6 6 This module provides a framework for the use of DNS Service Discovery
7 7 using IP multicast. It has been tested against the JRendezvous
8 8 implementation from <a href="http://strangeberry.com">StrangeBerry</a>,
9 9 and against the mDNSResponder from Mac OS X 10.3.8.
10 10
11 11 This library is free software; you can redistribute it and/or
12 12 modify it under the terms of the GNU Lesser General Public
13 13 License as published by the Free Software Foundation; either
14 14 version 2.1 of the License, or (at your option) any later version.
15 15
16 16 This library is distributed in the hope that it will be useful,
17 17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 19 Lesser General Public License for more details.
20 20
21 21 You should have received a copy of the GNU Lesser General Public
22 22 License along with this library; if not, see
23 23 <http://www.gnu.org/licenses/>.
24 24
25 25 """
26 26
27 27 """0.12 update - allow selection of binding interface
28 28 typo fix - Thanks A. M. Kuchlingi
29 29 removed all use of word 'Rendezvous' - this is an API change"""
30 30
31 31 """0.11 update - correction to comments for addListener method
32 32 support for new record types seen from OS X
33 33 - IPv6 address
34 34 - hostinfo
35 35 ignore unknown DNS record types
36 36 fixes to name decoding
37 37 works alongside other processes using port 5353 (e.g. Mac OS X)
38 38 tested against Mac OS X 10.3.2's mDNSResponder
39 39 corrections to removal of list entries for service browser"""
40 40
41 41 """0.10 update - Jonathon Paisley contributed these corrections:
42 42 always multicast replies, even when query is unicast
43 43 correct a pointer encoding problem
44 44 can now write records in any order
45 45 traceback shown on failure
46 46 better TXT record parsing
47 47 server is now separate from name
48 48 can cancel a service browser
49 49
50 50 modified some unit tests to accommodate these changes"""
51 51
52 52 """0.09 update - remove all records on service unregistration
53 53 fix DOS security problem with readName"""
54 54
55 55 """0.08 update - changed licensing to LGPL"""
56 56
57 57 """0.07 update - faster shutdown on engine
58 58 pointer encoding of outgoing names
59 59 ServiceBrowser now works
60 60 new unit tests"""
61 61
62 62 """0.06 update - small improvements with unit tests
63 63 added defined exception types
64 64 new style objects
65 65 fixed hostname/interface problem
66 66 fixed socket timeout problem
67 67 fixed addServiceListener() typo bug
68 68 using select() for socket reads
69 69 tested on Debian unstable with Python 2.2.2"""
70 70
71 71 """0.05 update - ensure case insensitivity on domain names
72 72 support for unicast DNS queries"""
73 73
74 74 """0.04 update - added some unit tests
75 75 added __ne__ adjuncts where required
76 76 ensure names end in '.local.'
77 77 timeout on receiving socket for clean shutdown"""
78 78
79 79 __author__ = b"Paul Scott-Murphy"
80 80 __email__ = b"paul at scott dash murphy dot com"
81 81 __version__ = b"0.12"
82 82
83 83 import errno
84 84 import itertools
85 85 import select
86 86 import socket
87 87 import struct
88 88 import threading
89 89 import time
90 90 import traceback
91 91
92 92 from mercurial import pycompat
93 93
94 94 __all__ = ["Zeroconf", "ServiceInfo", "ServiceBrowser"]
95 95
96 96 # hook for threads
97 97
98 globals()[b'_GLOBAL_DONE'] = 0
98 globals()['_GLOBAL_DONE'] = 0
99 99
100 100 # Some timing constants
101 101
102 102 _UNREGISTER_TIME = 125
103 103 _CHECK_TIME = 175
104 104 _REGISTER_TIME = 225
105 105 _LISTENER_TIME = 200
106 106 _BROWSER_TIME = 500
107 107
108 108 # Some DNS constants
109 109
110 110 _MDNS_ADDR = r'224.0.0.251'
111 111 _MDNS_PORT = 5353
112 112 _DNS_PORT = 53
113 113 _DNS_TTL = 60 * 60 # one hour default TTL
114 114
115 115 _MAX_MSG_TYPICAL = 1460 # unused
116 116 _MAX_MSG_ABSOLUTE = 8972
117 117
118 118 _FLAGS_QR_MASK = 0x8000 # query response mask
119 119 _FLAGS_QR_QUERY = 0x0000 # query
120 120 _FLAGS_QR_RESPONSE = 0x8000 # response
121 121
122 122 _FLAGS_AA = 0x0400 # Authoritative answer
123 123 _FLAGS_TC = 0x0200 # Truncated
124 124 _FLAGS_RD = 0x0100 # Recursion desired
125 125 _FLAGS_RA = 0x8000 # Recursion available
126 126
127 127 _FLAGS_Z = 0x0040 # Zero
128 128 _FLAGS_AD = 0x0020 # Authentic data
129 129 _FLAGS_CD = 0x0010 # Checking disabled
130 130
131 131 _CLASS_IN = 1
132 132 _CLASS_CS = 2
133 133 _CLASS_CH = 3
134 134 _CLASS_HS = 4
135 135 _CLASS_NONE = 254
136 136 _CLASS_ANY = 255
137 137 _CLASS_MASK = 0x7FFF
138 138 _CLASS_UNIQUE = 0x8000
139 139
140 140 _TYPE_A = 1
141 141 _TYPE_NS = 2
142 142 _TYPE_MD = 3
143 143 _TYPE_MF = 4
144 144 _TYPE_CNAME = 5
145 145 _TYPE_SOA = 6
146 146 _TYPE_MB = 7
147 147 _TYPE_MG = 8
148 148 _TYPE_MR = 9
149 149 _TYPE_NULL = 10
150 150 _TYPE_WKS = 11
151 151 _TYPE_PTR = 12
152 152 _TYPE_HINFO = 13
153 153 _TYPE_MINFO = 14
154 154 _TYPE_MX = 15
155 155 _TYPE_TXT = 16
156 156 _TYPE_AAAA = 28
157 157 _TYPE_SRV = 33
158 158 _TYPE_ANY = 255
159 159
160 160 # Mapping constants to names
161 161
162 162 _CLASSES = {
163 163 _CLASS_IN: b"in",
164 164 _CLASS_CS: b"cs",
165 165 _CLASS_CH: b"ch",
166 166 _CLASS_HS: b"hs",
167 167 _CLASS_NONE: b"none",
168 168 _CLASS_ANY: b"any",
169 169 }
170 170
171 171 _TYPES = {
172 172 _TYPE_A: b"a",
173 173 _TYPE_NS: b"ns",
174 174 _TYPE_MD: b"md",
175 175 _TYPE_MF: b"mf",
176 176 _TYPE_CNAME: b"cname",
177 177 _TYPE_SOA: b"soa",
178 178 _TYPE_MB: b"mb",
179 179 _TYPE_MG: b"mg",
180 180 _TYPE_MR: b"mr",
181 181 _TYPE_NULL: b"null",
182 182 _TYPE_WKS: b"wks",
183 183 _TYPE_PTR: b"ptr",
184 184 _TYPE_HINFO: b"hinfo",
185 185 _TYPE_MINFO: b"minfo",
186 186 _TYPE_MX: b"mx",
187 187 _TYPE_TXT: b"txt",
188 188 _TYPE_AAAA: b"quada",
189 189 _TYPE_SRV: b"srv",
190 190 _TYPE_ANY: b"any",
191 191 }
192 192
193 193 # utility functions
194 194
195 195
196 196 def currentTimeMillis():
197 197 """Current system time in milliseconds"""
198 198 return time.time() * 1000
199 199
200 200
201 201 # Exceptions
202 202
203 203
204 204 class NonLocalNameException(Exception):
205 205 pass
206 206
207 207
208 208 class NonUniqueNameException(Exception):
209 209 pass
210 210
211 211
212 212 class NamePartTooLongException(Exception):
213 213 pass
214 214
215 215
216 216 class AbstractMethodException(Exception):
217 217 pass
218 218
219 219
220 220 class BadTypeInNameException(Exception):
221 221 pass
222 222
223 223
224 224 class BadDomainName(Exception):
225 225 def __init__(self, pos):
226 226 Exception.__init__(self, b"at position %s" % pos)
227 227
228 228
229 229 class BadDomainNameCircular(BadDomainName):
230 230 pass
231 231
232 232
233 233 # implementation classes
234 234
235 235 _SOL_IP = socket.SOL_IP
236 236
237 237 if pycompat.iswindows:
238 238 # XXX: Not sure if there are newer versions of python where this would fail,
239 239 # but apparently socket.SOL_IP used to be 0, and socket.IPPROTO_IP is 0, so
240 240 # this would work with older versions of python.
241 241 #
242 242 # https://github.com/python/cpython/issues/101960
243 243 _SOL_IP = socket.IPPROTO_IP
244 244
245 245
246 246 class DNSEntry:
247 247 """A DNS entry"""
248 248
249 249 def __init__(self, name, type, clazz):
250 250 self.key = name.lower()
251 251 self.name = name
252 252 self.type = type
253 253 self.clazz = clazz & _CLASS_MASK
254 254 self.unique = (clazz & _CLASS_UNIQUE) != 0
255 255
256 256 def __eq__(self, other):
257 257 """Equality test on name, type, and class"""
258 258 if isinstance(other, DNSEntry):
259 259 return (
260 260 self.name == other.name
261 261 and self.type == other.type
262 262 and self.clazz == other.clazz
263 263 )
264 264 return 0
265 265
266 266 def __ne__(self, other):
267 267 """Non-equality test"""
268 268 return not self.__eq__(other)
269 269
270 270 def getClazz(self, clazz):
271 271 """Class accessor"""
272 272 try:
273 273 return _CLASSES[clazz]
274 274 except KeyError:
275 275 return b"?(%s)" % clazz
276 276
277 277 def getType(self, type):
278 278 """Type accessor"""
279 279 try:
280 280 return _TYPES[type]
281 281 except KeyError:
282 282 return b"?(%s)" % type
283 283
284 284 def toString(self, hdr, other):
285 285 """String representation with additional information"""
286 286 result = b"%s[%s,%s" % (
287 287 hdr,
288 288 self.getType(self.type),
289 289 self.getClazz(self.clazz),
290 290 )
291 291 if self.unique:
292 292 result += b"-unique,"
293 293 else:
294 294 result += b","
295 295 result += self.name
296 296 if other is not None:
297 297 result += b",%s]" % other
298 298 else:
299 299 result += b"]"
300 300 return result
301 301
302 302
303 303 class DNSQuestion(DNSEntry):
304 304 """A DNS question entry"""
305 305
306 306 def __init__(self, name, type, clazz):
307 307 if isinstance(name, str):
308 308 name = name.encode('ascii')
309 309 if not name.endswith(b".local."):
310 310 raise NonLocalNameException(name)
311 311 DNSEntry.__init__(self, name, type, clazz)
312 312
313 313 def answeredBy(self, rec):
314 314 """Returns true if the question is answered by the record"""
315 315 return (
316 316 self.clazz == rec.clazz
317 317 and (self.type == rec.type or self.type == _TYPE_ANY)
318 318 and self.name == rec.name
319 319 )
320 320
321 321 def __repr__(self):
322 322 """String representation"""
323 323 return DNSEntry.toString(self, b"question", None)
324 324
325 325
326 326 class DNSRecord(DNSEntry):
327 327 """A DNS record - like a DNS entry, but has a TTL"""
328 328
329 329 def __init__(self, name, type, clazz, ttl):
330 330 DNSEntry.__init__(self, name, type, clazz)
331 331 self.ttl = ttl
332 332 self.created = currentTimeMillis()
333 333
334 334 def __eq__(self, other):
335 335 """Tests equality as per DNSRecord"""
336 336 if isinstance(other, DNSRecord):
337 337 return DNSEntry.__eq__(self, other)
338 338 return 0
339 339
340 340 def suppressedBy(self, msg):
341 341 """Returns true if any answer in a message can suffice for the
342 342 information held in this record."""
343 343 for record in msg.answers:
344 344 if self.suppressedByAnswer(record):
345 345 return 1
346 346 return 0
347 347
348 348 def suppressedByAnswer(self, other):
349 349 """Returns true if another record has same name, type and class,
350 350 and if its TTL is at least half of this record's."""
351 351 if self == other and other.ttl > (self.ttl / 2):
352 352 return 1
353 353 return 0
354 354
355 355 def getExpirationTime(self, percent):
356 356 """Returns the time at which this record will have expired
357 357 by a certain percentage."""
358 358 return self.created + (percent * self.ttl * 10)
359 359
360 360 def getRemainingTTL(self, now):
361 361 """Returns the remaining TTL in seconds."""
362 362 return max(0, (self.getExpirationTime(100) - now) / 1000)
363 363
364 364 def isExpired(self, now):
365 365 """Returns true if this record has expired."""
366 366 return self.getExpirationTime(100) <= now
367 367
368 368 def isStale(self, now):
369 369 """Returns true if this record is at least half way expired."""
370 370 return self.getExpirationTime(50) <= now
371 371
372 372 def resetTTL(self, other):
373 373 """Sets this record's TTL and created time to that of
374 374 another record."""
375 375 self.created = other.created
376 376 self.ttl = other.ttl
377 377
378 378 def write(self, out):
379 379 """Abstract method"""
380 380 raise AbstractMethodException
381 381
382 382 def toString(self, other):
383 383 """String representation with additional information"""
384 384 arg = b"%s/%s,%s" % (
385 385 self.ttl,
386 386 self.getRemainingTTL(currentTimeMillis()),
387 387 other,
388 388 )
389 389 return DNSEntry.toString(self, b"record", arg)
390 390
391 391
392 392 class DNSAddress(DNSRecord):
393 393 """A DNS address record"""
394 394
395 395 def __init__(self, name, type, clazz, ttl, address):
396 396 DNSRecord.__init__(self, name, type, clazz, ttl)
397 397 self.address = address
398 398
399 399 def write(self, out):
400 400 """Used in constructing an outgoing packet"""
401 401 out.writeString(self.address, len(self.address))
402 402
403 403 def __eq__(self, other):
404 404 """Tests equality on address"""
405 405 if isinstance(other, DNSAddress):
406 406 return self.address == other.address
407 407 return 0
408 408
409 409 def __repr__(self):
410 410 """String representation"""
411 411 try:
412 412 return socket.inet_ntoa(self.address)
413 413 except Exception:
414 414 return self.address
415 415
416 416
417 417 class DNSHinfo(DNSRecord):
418 418 """A DNS host information record"""
419 419
420 420 def __init__(self, name, type, clazz, ttl, cpu, os):
421 421 DNSRecord.__init__(self, name, type, clazz, ttl)
422 422 self.cpu = cpu
423 423 self.os = os
424 424
425 425 def write(self, out):
426 426 """Used in constructing an outgoing packet"""
427 427 out.writeString(self.cpu, len(self.cpu))
428 428 out.writeString(self.os, len(self.os))
429 429
430 430 def __eq__(self, other):
431 431 """Tests equality on cpu and os"""
432 432 if isinstance(other, DNSHinfo):
433 433 return self.cpu == other.cpu and self.os == other.os
434 434 return 0
435 435
436 436 def __repr__(self):
437 437 """String representation"""
438 438 return self.cpu + b" " + self.os
439 439
440 440
441 441 class DNSPointer(DNSRecord):
442 442 """A DNS pointer record"""
443 443
444 444 def __init__(self, name, type, clazz, ttl, alias):
445 445 DNSRecord.__init__(self, name, type, clazz, ttl)
446 446 self.alias = alias
447 447
448 448 def write(self, out):
449 449 """Used in constructing an outgoing packet"""
450 450 out.writeName(self.alias)
451 451
452 452 def __eq__(self, other):
453 453 """Tests equality on alias"""
454 454 if isinstance(other, DNSPointer):
455 455 return self.alias == other.alias
456 456 return 0
457 457
458 458 def __repr__(self):
459 459 """String representation"""
460 460 return self.toString(self.alias)
461 461
462 462
463 463 class DNSText(DNSRecord):
464 464 """A DNS text record"""
465 465
466 466 def __init__(self, name, type, clazz, ttl, text):
467 467 DNSRecord.__init__(self, name, type, clazz, ttl)
468 468 self.text = text
469 469
470 470 def write(self, out):
471 471 """Used in constructing an outgoing packet"""
472 472 out.writeString(self.text, len(self.text))
473 473
474 474 def __eq__(self, other):
475 475 """Tests equality on text"""
476 476 if isinstance(other, DNSText):
477 477 return self.text == other.text
478 478 return 0
479 479
480 480 def __repr__(self):
481 481 """String representation"""
482 482 if len(self.text) > 10:
483 483 return self.toString(self.text[:7] + b"...")
484 484 else:
485 485 return self.toString(self.text)
486 486
487 487
488 488 class DNSService(DNSRecord):
489 489 """A DNS service record"""
490 490
491 491 def __init__(self, name, type, clazz, ttl, priority, weight, port, server):
492 492 DNSRecord.__init__(self, name, type, clazz, ttl)
493 493 self.priority = priority
494 494 self.weight = weight
495 495 self.port = port
496 496 self.server = server
497 497
498 498 def write(self, out):
499 499 """Used in constructing an outgoing packet"""
500 500 out.writeShort(self.priority)
501 501 out.writeShort(self.weight)
502 502 out.writeShort(self.port)
503 503 out.writeName(self.server)
504 504
505 505 def __eq__(self, other):
506 506 """Tests equality on priority, weight, port and server"""
507 507 if isinstance(other, DNSService):
508 508 return (
509 509 self.priority == other.priority
510 510 and self.weight == other.weight
511 511 and self.port == other.port
512 512 and self.server == other.server
513 513 )
514 514 return 0
515 515
516 516 def __repr__(self):
517 517 """String representation"""
518 518 return self.toString(b"%s:%s" % (self.server, self.port))
519 519
520 520
521 521 class DNSIncoming:
522 522 """Object representation of an incoming DNS packet"""
523 523
524 524 def __init__(self, data):
525 525 """Constructor from string holding bytes of packet"""
526 526 self.offset = 0
527 527 self.data = data
528 528 self.questions = []
529 529 self.answers = []
530 530 self.numquestions = 0
531 531 self.numanswers = 0
532 532 self.numauthorities = 0
533 533 self.numadditionals = 0
534 534
535 535 self.readHeader()
536 536 self.readQuestions()
537 537 self.readOthers()
538 538
539 539 def readHeader(self):
540 540 """Reads header portion of packet"""
541 541 format = b'!HHHHHH'
542 542 length = struct.calcsize(format)
543 543 info = struct.unpack(
544 544 format, self.data[self.offset : self.offset + length]
545 545 )
546 546 self.offset += length
547 547
548 548 self.id = info[0]
549 549 self.flags = info[1]
550 550 self.numquestions = info[2]
551 551 self.numanswers = info[3]
552 552 self.numauthorities = info[4]
553 553 self.numadditionals = info[5]
554 554
555 555 def readQuestions(self):
556 556 """Reads questions section of packet"""
557 557 format = b'!HH'
558 558 length = struct.calcsize(format)
559 559 for i in range(0, self.numquestions):
560 560 name = self.readName()
561 561 info = struct.unpack(
562 562 format, self.data[self.offset : self.offset + length]
563 563 )
564 564 self.offset += length
565 565
566 566 try:
567 567 question = DNSQuestion(name, info[0], info[1])
568 568 self.questions.append(question)
569 569 except NonLocalNameException:
570 570 pass
571 571
572 572 def readInt(self):
573 573 """Reads an integer from the packet"""
574 574 format = b'!I'
575 575 length = struct.calcsize(format)
576 576 info = struct.unpack(
577 577 format, self.data[self.offset : self.offset + length]
578 578 )
579 579 self.offset += length
580 580 return info[0]
581 581
582 582 def readCharacterString(self):
583 583 """Reads a character string from the packet"""
584 584 length = ord(self.data[self.offset])
585 585 self.offset += 1
586 586 return self.readString(length)
587 587
588 588 def readString(self, len):
589 589 """Reads a string of a given length from the packet"""
590 590 format = b'!%ds' % len
591 591 length = struct.calcsize(format)
592 592 info = struct.unpack(
593 593 format, self.data[self.offset : self.offset + length]
594 594 )
595 595 self.offset += length
596 596 return info[0]
597 597
598 598 def readUnsignedShort(self):
599 599 """Reads an unsigned short from the packet"""
600 600 format = b'!H'
601 601 length = struct.calcsize(format)
602 602 info = struct.unpack(
603 603 format, self.data[self.offset : self.offset + length]
604 604 )
605 605 self.offset += length
606 606 return info[0]
607 607
608 608 def readOthers(self):
609 609 """Reads answers, authorities and additionals section of the packet"""
610 610 format = b'!HHiH'
611 611 length = struct.calcsize(format)
612 612 n = self.numanswers + self.numauthorities + self.numadditionals
613 613 for i in range(0, n):
614 614 domain = self.readName()
615 615 info = struct.unpack(
616 616 format, self.data[self.offset : self.offset + length]
617 617 )
618 618 self.offset += length
619 619
620 620 rec = None
621 621 if info[0] == _TYPE_A:
622 622 rec = DNSAddress(
623 623 domain, info[0], info[1], info[2], self.readString(4)
624 624 )
625 625 elif info[0] == _TYPE_CNAME or info[0] == _TYPE_PTR:
626 626 rec = DNSPointer(
627 627 domain, info[0], info[1], info[2], self.readName()
628 628 )
629 629 elif info[0] == _TYPE_TXT:
630 630 rec = DNSText(
631 631 domain, info[0], info[1], info[2], self.readString(info[3])
632 632 )
633 633 elif info[0] == _TYPE_SRV:
634 634 rec = DNSService(
635 635 domain,
636 636 info[0],
637 637 info[1],
638 638 info[2],
639 639 self.readUnsignedShort(),
640 640 self.readUnsignedShort(),
641 641 self.readUnsignedShort(),
642 642 self.readName(),
643 643 )
644 644 elif info[0] == _TYPE_HINFO:
645 645 rec = DNSHinfo(
646 646 domain,
647 647 info[0],
648 648 info[1],
649 649 info[2],
650 650 self.readCharacterString(),
651 651 self.readCharacterString(),
652 652 )
653 653 elif info[0] == _TYPE_AAAA:
654 654 rec = DNSAddress(
655 655 domain, info[0], info[1], info[2], self.readString(16)
656 656 )
657 657 else:
658 658 # Try to ignore types we don't know about
659 659 # this may mean the rest of the name is
660 660 # unable to be parsed, and may show errors
661 661 # so this is left for debugging. New types
662 662 # encountered need to be parsed properly.
663 663 #
664 664 # print "UNKNOWN TYPE = " + str(info[0])
665 665 # raise BadTypeInNameException
666 666 self.offset += info[3]
667 667
668 668 if rec is not None:
669 669 self.answers.append(rec)
670 670
671 671 def isQuery(self):
672 672 """Returns true if this is a query"""
673 673 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
674 674
675 675 def isResponse(self):
676 676 """Returns true if this is a response"""
677 677 return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
678 678
679 679 def readUTF(self, offset, len):
680 680 """Reads a UTF-8 string of a given length from the packet"""
681 681 return self.data[offset : offset + len].decode('utf-8')
682 682
683 683 def readName(self):
684 684 """Reads a domain name from the packet"""
685 685 result = r''
686 686 off = self.offset
687 687 next = -1
688 688 first = off
689 689
690 690 while True:
691 691 len = ord(self.data[off : off + 1])
692 692 off += 1
693 693 if len == 0:
694 694 break
695 695 t = len & 0xC0
696 696 if t == 0x00:
697 697 result = ''.join((result, self.readUTF(off, len) + '.'))
698 698 off += len
699 699 elif t == 0xC0:
700 700 if next < 0:
701 701 next = off + 1
702 702 off = ((len & 0x3F) << 8) | ord(self.data[off : off + 1])
703 703 if off >= first:
704 704 raise BadDomainNameCircular(off)
705 705 first = off
706 706 else:
707 707 raise BadDomainName(off)
708 708
709 709 if next >= 0:
710 710 self.offset = next
711 711 else:
712 712 self.offset = off
713 713
714 714 return result
715 715
716 716
717 717 class DNSOutgoing:
718 718 """Object representation of an outgoing packet"""
719 719
720 720 def __init__(self, flags, multicast=1):
721 721 self.finished = 0
722 722 self.id = 0
723 723 self.multicast = multicast
724 724 self.flags = flags
725 725 self.names = {}
726 726 self.data = []
727 727 self.size = 12
728 728
729 729 self.questions = []
730 730 self.answers = []
731 731 self.authorities = []
732 732 self.additionals = []
733 733
734 734 def addQuestion(self, record):
735 735 """Adds a question"""
736 736 self.questions.append(record)
737 737
738 738 def addAnswer(self, inp, record):
739 739 """Adds an answer"""
740 740 if not record.suppressedBy(inp):
741 741 self.addAnswerAtTime(record, 0)
742 742
743 743 def addAnswerAtTime(self, record, now):
744 744 """Adds an answer if if does not expire by a certain time"""
745 745 if record is not None:
746 746 if now == 0 or not record.isExpired(now):
747 747 self.answers.append((record, now))
748 748
749 749 def addAuthoritativeAnswer(self, record):
750 750 """Adds an authoritative answer"""
751 751 self.authorities.append(record)
752 752
753 753 def addAdditionalAnswer(self, record):
754 754 """Adds an additional answer"""
755 755 self.additionals.append(record)
756 756
757 757 def writeByte(self, value):
758 758 """Writes a single byte to the packet"""
759 759 format = b'!c'
760 760 self.data.append(struct.pack(format, chr(value)))
761 761 self.size += 1
762 762
763 763 def insertShort(self, index, value):
764 764 """Inserts an unsigned short in a certain position in the packet"""
765 765 format = b'!H'
766 766 self.data.insert(index, struct.pack(format, value))
767 767 self.size += 2
768 768
769 769 def writeShort(self, value):
770 770 """Writes an unsigned short to the packet"""
771 771 format = b'!H'
772 772 self.data.append(struct.pack(format, value))
773 773 self.size += 2
774 774
775 775 def writeInt(self, value):
776 776 """Writes an unsigned integer to the packet"""
777 777 format = b'!I'
778 778 self.data.append(struct.pack(format, int(value)))
779 779 self.size += 4
780 780
781 781 def writeString(self, value, length):
782 782 """Writes a string to the packet"""
783 783 format = '!' + str(length) + 's'
784 784 self.data.append(struct.pack(format, value))
785 785 self.size += length
786 786
787 787 def writeUTF(self, s):
788 788 """Writes a UTF-8 string of a given length to the packet"""
789 789 utfstr = s.encode('utf-8')
790 790 length = len(utfstr)
791 791 if length > 64:
792 792 raise NamePartTooLongException
793 793 self.writeByte(length)
794 794 self.writeString(utfstr, length)
795 795
796 796 def writeName(self, name):
797 797 """Writes a domain name to the packet"""
798 798
799 799 try:
800 800 # Find existing instance of this name in packet
801 801 #
802 802 index = self.names[name]
803 803 except KeyError:
804 804 # No record of this name already, so write it
805 805 # out as normal, recording the location of the name
806 806 # for future pointers to it.
807 807 #
808 808 self.names[name] = self.size
809 809 parts = name.split(b'.')
810 810 if parts[-1] == b'':
811 811 parts = parts[:-1]
812 812 for part in parts:
813 813 self.writeUTF(part)
814 814 self.writeByte(0)
815 815 return
816 816
817 817 # An index was found, so write a pointer to it
818 818 #
819 819 self.writeByte((index >> 8) | 0xC0)
820 820 self.writeByte(index)
821 821
822 822 def writeQuestion(self, question):
823 823 """Writes a question to the packet"""
824 824 self.writeName(question.name)
825 825 self.writeShort(question.type)
826 826 self.writeShort(question.clazz)
827 827
828 828 def writeRecord(self, record, now):
829 829 """Writes a record (answer, authoritative answer, additional) to
830 830 the packet"""
831 831 self.writeName(record.name)
832 832 self.writeShort(record.type)
833 833 if record.unique and self.multicast:
834 834 self.writeShort(record.clazz | _CLASS_UNIQUE)
835 835 else:
836 836 self.writeShort(record.clazz)
837 837 if now == 0:
838 838 self.writeInt(record.ttl)
839 839 else:
840 840 self.writeInt(record.getRemainingTTL(now))
841 841 index = len(self.data)
842 842 # Adjust size for the short we will write before this record
843 843 #
844 844 self.size += 2
845 845 record.write(self)
846 846 self.size -= 2
847 847
848 848 length = len(b''.join(self.data[index:]))
849 849 self.insertShort(index, length) # Here is the short we adjusted for
850 850
851 851 def packet(self):
852 852 """Returns a string containing the packet's bytes
853 853
854 854 No further parts should be added to the packet once this
855 855 is done."""
856 856 if not self.finished:
857 857 self.finished = 1
858 858 for question in self.questions:
859 859 self.writeQuestion(question)
860 860 for answer, time_ in self.answers:
861 861 self.writeRecord(answer, time_)
862 862 for authority in self.authorities:
863 863 self.writeRecord(authority, 0)
864 864 for additional in self.additionals:
865 865 self.writeRecord(additional, 0)
866 866
867 867 self.insertShort(0, len(self.additionals))
868 868 self.insertShort(0, len(self.authorities))
869 869 self.insertShort(0, len(self.answers))
870 870 self.insertShort(0, len(self.questions))
871 871 self.insertShort(0, self.flags)
872 872 if self.multicast:
873 873 self.insertShort(0, 0)
874 874 else:
875 875 self.insertShort(0, self.id)
876 876 return b''.join(self.data)
877 877
878 878
879 879 class DNSCache:
880 880 """A cache of DNS entries"""
881 881
882 882 def __init__(self):
883 883 self.cache = {}
884 884
885 885 def add(self, entry):
886 886 """Adds an entry"""
887 887 try:
888 888 list = self.cache[entry.key]
889 889 except KeyError:
890 890 list = self.cache[entry.key] = []
891 891 list.append(entry)
892 892
893 893 def remove(self, entry):
894 894 """Removes an entry"""
895 895 try:
896 896 list = self.cache[entry.key]
897 897 list.remove(entry)
898 898 except KeyError:
899 899 pass
900 900
901 901 def get(self, entry):
902 902 """Gets an entry by key. Will return None if there is no
903 903 matching entry."""
904 904 try:
905 905 list = self.cache[entry.key]
906 906 return list[list.index(entry)]
907 907 except (KeyError, ValueError):
908 908 return None
909 909
910 910 def getByDetails(self, name, type, clazz):
911 911 """Gets an entry by details. Will return None if there is
912 912 no matching entry."""
913 913 entry = DNSEntry(name, type, clazz)
914 914 return self.get(entry)
915 915
916 916 def entriesWithName(self, name):
917 917 """Returns a list of entries whose key matches the name."""
918 918 try:
919 919 return self.cache[name]
920 920 except KeyError:
921 921 return []
922 922
923 923 def entries(self):
924 924 """Returns a list of all entries"""
925 925 try:
926 926 return list(itertools.chain.from_iterable(self.cache.values()))
927 927 except Exception:
928 928 return []
929 929
930 930
931 931 class Engine(threading.Thread):
932 932 """An engine wraps read access to sockets, allowing objects that
933 933 need to receive data from sockets to be called back when the
934 934 sockets are ready.
935 935
936 936 A reader needs a handle_read() method, which is called when the socket
937 937 it is interested in is ready for reading.
938 938
939 939 Writers are not implemented here, because we only send short
940 940 packets.
941 941 """
942 942
943 943 def __init__(self, zeroconf):
944 944 threading.Thread.__init__(self)
945 945 self.zeroconf = zeroconf
946 946 self.readers = {} # maps socket to reader
947 947 self.timeout = 5
948 948 self.condition = threading.Condition()
949 949 self.start()
950 950
951 951 def run(self):
952 while not globals()[b'_GLOBAL_DONE']:
952 while not globals()['_GLOBAL_DONE']:
953 953 rs = self.getReaders()
954 954 if len(rs) == 0:
955 955 # No sockets to manage, but we wait for the timeout
956 956 # or addition of a socket
957 957 #
958 958 self.condition.acquire()
959 959 self.condition.wait(self.timeout)
960 960 self.condition.release()
961 961 else:
962 962 try:
963 963 rr, wr, er = select.select(rs, [], [], self.timeout)
964 964 for sock in rr:
965 965 try:
966 966 self.readers[sock].handle_read()
967 967 except Exception:
968 if not globals()[b'_GLOBAL_DONE']:
968 if not globals()['_GLOBAL_DONE']:
969 969 traceback.print_exc()
970 970 except Exception:
971 971 pass
972 972
973 973 def getReaders(self):
974 974 self.condition.acquire()
975 975 result = self.readers.keys()
976 976 self.condition.release()
977 977 return result
978 978
979 979 def addReader(self, reader, socket):
980 980 self.condition.acquire()
981 981 self.readers[socket] = reader
982 982 self.condition.notify()
983 983 self.condition.release()
984 984
985 985 def delReader(self, socket):
986 986 self.condition.acquire()
987 987 del self.readers[socket]
988 988 self.condition.notify()
989 989 self.condition.release()
990 990
991 991 def notify(self):
992 992 self.condition.acquire()
993 993 self.condition.notify()
994 994 self.condition.release()
995 995
996 996
997 997 class Listener:
998 998 """A Listener is used by this module to listen on the multicast
999 999 group to which DNS messages are sent, allowing the implementation
1000 1000 to cache information as it arrives.
1001 1001
1002 1002 It requires registration with an Engine object in order to have
1003 1003 the read() method called when a socket is available for reading."""
1004 1004
1005 1005 def __init__(self, zeroconf):
1006 1006 self.zeroconf = zeroconf
1007 1007 self.zeroconf.engine.addReader(self, self.zeroconf.socket)
1008 1008
1009 1009 def handle_read(self):
1010 1010 sock = self.zeroconf.socket
1011 1011 try:
1012 1012 data, (addr, port) = sock.recvfrom(_MAX_MSG_ABSOLUTE)
1013 1013 except socket.error as e:
1014 1014 if e.errno == errno.EBADF:
1015 1015 # some other thread may close the socket
1016 1016 return
1017 1017 else:
1018 1018 raise
1019 1019 self.data = data
1020 1020 msg = DNSIncoming(data)
1021 1021 if msg.isQuery():
1022 1022 # Always multicast responses
1023 1023 #
1024 1024 if port == _MDNS_PORT:
1025 1025 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
1026 1026 # If it's not a multicast query, reply via unicast
1027 1027 # and multicast
1028 1028 #
1029 1029 elif port == _DNS_PORT:
1030 1030 self.zeroconf.handleQuery(msg, addr, port)
1031 1031 self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
1032 1032 else:
1033 1033 self.zeroconf.handleResponse(msg)
1034 1034
1035 1035
1036 1036 class Reaper(threading.Thread):
1037 1037 """A Reaper is used by this module to remove cache entries that
1038 1038 have expired."""
1039 1039
1040 1040 def __init__(self, zeroconf):
1041 1041 threading.Thread.__init__(self)
1042 1042 self.zeroconf = zeroconf
1043 1043 self.start()
1044 1044
1045 1045 def run(self):
1046 1046 while True:
1047 1047 self.zeroconf.wait(10 * 1000)
1048 if globals()[b'_GLOBAL_DONE']:
1048 if globals()['_GLOBAL_DONE']:
1049 1049 return
1050 1050 now = currentTimeMillis()
1051 1051 for record in self.zeroconf.cache.entries():
1052 1052 if record.isExpired(now):
1053 1053 self.zeroconf.updateRecord(now, record)
1054 1054 self.zeroconf.cache.remove(record)
1055 1055
1056 1056
1057 1057 class ServiceBrowser(threading.Thread):
1058 1058 """Used to browse for a service of a specific type.
1059 1059
1060 1060 The listener object will have its addService() and
1061 1061 removeService() methods called when this browser
1062 1062 discovers changes in the services availability."""
1063 1063
1064 1064 def __init__(self, zeroconf, type, listener):
1065 1065 """Creates a browser for a specific type"""
1066 1066 threading.Thread.__init__(self)
1067 1067 self.zeroconf = zeroconf
1068 1068 self.type = type
1069 1069 self.listener = listener
1070 1070 self.services = {}
1071 1071 self.nexttime = currentTimeMillis()
1072 1072 self.delay = _BROWSER_TIME
1073 1073 self.list = []
1074 1074
1075 1075 self.done = 0
1076 1076
1077 1077 self.zeroconf.addListener(
1078 1078 self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN)
1079 1079 )
1080 1080 self.start()
1081 1081
1082 1082 def updateRecord(self, zeroconf, now, record):
1083 1083 """Callback invoked by Zeroconf when new information arrives.
1084 1084
1085 1085 Updates information required by browser in the Zeroconf cache."""
1086 1086 if record.type == _TYPE_PTR and record.name == self.type:
1087 1087 expired = record.isExpired(now)
1088 1088 try:
1089 1089 oldrecord = self.services[record.alias.lower()]
1090 1090 if not expired:
1091 1091 oldrecord.resetTTL(record)
1092 1092 else:
1093 1093 del self.services[record.alias.lower()]
1094 1094 callback = lambda x: self.listener.removeService(
1095 1095 x, self.type, record.alias
1096 1096 )
1097 1097 self.list.append(callback)
1098 1098 return
1099 1099 except Exception:
1100 1100 if not expired:
1101 1101 self.services[record.alias.lower()] = record
1102 1102 callback = lambda x: self.listener.addService(
1103 1103 x, self.type, record.alias
1104 1104 )
1105 1105 self.list.append(callback)
1106 1106
1107 1107 expires = record.getExpirationTime(75)
1108 1108 if expires < self.nexttime:
1109 1109 self.nexttime = expires
1110 1110
1111 1111 def cancel(self):
1112 1112 self.done = 1
1113 1113 self.zeroconf.notifyAll()
1114 1114
1115 1115 def run(self):
1116 1116 while True:
1117 1117 event = None
1118 1118 now = currentTimeMillis()
1119 1119 if len(self.list) == 0 and self.nexttime > now:
1120 1120 self.zeroconf.wait(self.nexttime - now)
1121 if globals()[b'_GLOBAL_DONE'] or self.done:
1121 if globals()['_GLOBAL_DONE'] or self.done:
1122 1122 return
1123 1123 now = currentTimeMillis()
1124 1124
1125 1125 if self.nexttime <= now:
1126 1126 out = DNSOutgoing(_FLAGS_QR_QUERY)
1127 1127 out.addQuestion(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1128 1128 for record in self.services.values():
1129 1129 if not record.isExpired(now):
1130 1130 out.addAnswerAtTime(record, now)
1131 1131 self.zeroconf.send(out)
1132 1132 self.nexttime = now + self.delay
1133 1133 self.delay = min(20 * 1000, self.delay * 2)
1134 1134
1135 1135 if len(self.list) > 0:
1136 1136 event = self.list.pop(0)
1137 1137
1138 1138 if event is not None:
1139 1139 event(self.zeroconf)
1140 1140
1141 1141
1142 1142 class ServiceInfo:
1143 1143 """Service information"""
1144 1144
1145 1145 def __init__(
1146 1146 self,
1147 1147 type,
1148 1148 name,
1149 1149 address=None,
1150 1150 port=None,
1151 1151 weight=0,
1152 1152 priority=0,
1153 1153 properties=None,
1154 1154 server=None,
1155 1155 ):
1156 1156 """Create a service description.
1157 1157
1158 1158 type: fully qualified service type name
1159 1159 name: fully qualified service name
1160 1160 address: IP address as unsigned short, network byte order
1161 1161 port: port that the service runs on
1162 1162 weight: weight of the service
1163 1163 priority: priority of the service
1164 1164 properties: dictionary of properties (or a string holding the bytes for
1165 1165 the text field)
1166 1166 server: fully qualified name for service host (defaults to name)"""
1167 1167
1168 1168 if not name.endswith(type):
1169 1169 raise BadTypeInNameException
1170 1170 self.type = type
1171 1171 self.name = name
1172 1172 self.address = address
1173 1173 self.port = port
1174 1174 self.weight = weight
1175 1175 self.priority = priority
1176 1176 if server:
1177 1177 self.server = server
1178 1178 else:
1179 1179 self.server = name
1180 1180 self.setProperties(properties)
1181 1181
1182 1182 def setProperties(self, properties):
1183 1183 """Sets properties and text of this info from a dictionary"""
1184 1184 if isinstance(properties, dict):
1185 1185 self.properties = properties
1186 1186 list = []
1187 1187 result = b''
1188 1188 for key in properties:
1189 1189 value = properties[key]
1190 1190 if value is None:
1191 1191 suffix = b''
1192 1192 elif isinstance(value, str):
1193 1193 suffix = value
1194 1194 elif isinstance(value, int):
1195 1195 if value:
1196 1196 suffix = b'true'
1197 1197 else:
1198 1198 suffix = b'false'
1199 1199 else:
1200 1200 suffix = b''
1201 1201 list.append(b'='.join((key, suffix)))
1202 1202 for item in list:
1203 1203 result = b''.join(
1204 1204 (
1205 1205 result,
1206 1206 struct.pack(b'!c', pycompat.bytechr(len(item))),
1207 1207 item,
1208 1208 )
1209 1209 )
1210 1210 self.text = result
1211 1211 else:
1212 1212 self.text = properties
1213 1213
1214 1214 def setText(self, text):
1215 1215 """Sets properties and text given a text field"""
1216 1216 self.text = text
1217 1217 try:
1218 1218 result = {}
1219 1219 end = len(text)
1220 1220 index = 0
1221 1221 strs = []
1222 1222 while index < end:
1223 1223 length = ord(text[index])
1224 1224 index += 1
1225 1225 strs.append(text[index : index + length])
1226 1226 index += length
1227 1227
1228 1228 for s in strs:
1229 1229 eindex = s.find(b'=')
1230 1230 if eindex == -1:
1231 1231 # No equals sign at all
1232 1232 key = s
1233 1233 value = 0
1234 1234 else:
1235 1235 key = s[:eindex]
1236 1236 value = s[eindex + 1 :]
1237 1237 if value == b'true':
1238 1238 value = 1
1239 1239 elif value == b'false' or not value:
1240 1240 value = 0
1241 1241
1242 1242 # Only update non-existent properties
1243 1243 if key and result.get(key) is None:
1244 1244 result[key] = value
1245 1245
1246 1246 self.properties = result
1247 1247 except Exception:
1248 1248 traceback.print_exc()
1249 1249 self.properties = None
1250 1250
1251 1251 def getType(self):
1252 1252 """Type accessor"""
1253 1253 return self.type
1254 1254
1255 1255 def getName(self):
1256 1256 """Name accessor"""
1257 1257 if self.type is not None and self.name.endswith(b"." + self.type):
1258 1258 return self.name[: len(self.name) - len(self.type) - 1]
1259 1259 return self.name
1260 1260
1261 1261 def getAddress(self):
1262 1262 """Address accessor"""
1263 1263 return self.address
1264 1264
1265 1265 def getPort(self):
1266 1266 """Port accessor"""
1267 1267 return self.port
1268 1268
1269 1269 def getPriority(self):
1270 1270 """Priority accessor"""
1271 1271 return self.priority
1272 1272
1273 1273 def getWeight(self):
1274 1274 """Weight accessor"""
1275 1275 return self.weight
1276 1276
1277 1277 def getProperties(self):
1278 1278 """Properties accessor"""
1279 1279 return self.properties
1280 1280
1281 1281 def getText(self):
1282 1282 """Text accessor"""
1283 1283 return self.text
1284 1284
1285 1285 def getServer(self):
1286 1286 """Server accessor"""
1287 1287 return self.server
1288 1288
1289 1289 def updateRecord(self, zeroconf, now, record):
1290 1290 """Updates service information from a DNS record"""
1291 1291 if record is not None and not record.isExpired(now):
1292 1292 if record.type == _TYPE_A:
1293 1293 # if record.name == self.name:
1294 1294 if record.name == self.server:
1295 1295 self.address = record.address
1296 1296 elif record.type == _TYPE_SRV:
1297 1297 if record.name == self.name:
1298 1298 self.server = record.server
1299 1299 self.port = record.port
1300 1300 self.weight = record.weight
1301 1301 self.priority = record.priority
1302 1302 # self.address = None
1303 1303 self.updateRecord(
1304 1304 zeroconf,
1305 1305 now,
1306 1306 zeroconf.cache.getByDetails(
1307 1307 self.server, _TYPE_A, _CLASS_IN
1308 1308 ),
1309 1309 )
1310 1310 elif record.type == _TYPE_TXT:
1311 1311 if record.name == self.name:
1312 1312 self.setText(record.text)
1313 1313
1314 1314 def request(self, zeroconf, timeout):
1315 1315 """Returns true if the service could be discovered on the
1316 1316 network, and updates this object with details discovered.
1317 1317 """
1318 1318 now = currentTimeMillis()
1319 1319 delay = _LISTENER_TIME
1320 1320 next = now + delay
1321 1321 last = now + timeout
1322 1322 result = False
1323 1323 try:
1324 1324 zeroconf.addListener(
1325 1325 self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)
1326 1326 )
1327 1327 while (
1328 1328 self.server is None or self.address is None or self.text is None
1329 1329 ):
1330 1330 if last <= now:
1331 1331 return 0
1332 1332 if next <= now:
1333 1333 out = DNSOutgoing(_FLAGS_QR_QUERY)
1334 1334 out.addQuestion(
1335 1335 DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN)
1336 1336 )
1337 1337 out.addAnswerAtTime(
1338 1338 zeroconf.cache.getByDetails(
1339 1339 self.name, _TYPE_SRV, _CLASS_IN
1340 1340 ),
1341 1341 now,
1342 1342 )
1343 1343 out.addQuestion(
1344 1344 DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN)
1345 1345 )
1346 1346 out.addAnswerAtTime(
1347 1347 zeroconf.cache.getByDetails(
1348 1348 self.name, _TYPE_TXT, _CLASS_IN
1349 1349 ),
1350 1350 now,
1351 1351 )
1352 1352 if self.server is not None:
1353 1353 out.addQuestion(
1354 1354 DNSQuestion(self.server, _TYPE_A, _CLASS_IN)
1355 1355 )
1356 1356 out.addAnswerAtTime(
1357 1357 zeroconf.cache.getByDetails(
1358 1358 self.server, _TYPE_A, _CLASS_IN
1359 1359 ),
1360 1360 now,
1361 1361 )
1362 1362 zeroconf.send(out)
1363 1363 next = now + delay
1364 1364 delay = delay * 2
1365 1365
1366 1366 zeroconf.wait(min(next, last) - now)
1367 1367 now = currentTimeMillis()
1368 1368 result = True
1369 1369 finally:
1370 1370 zeroconf.removeListener(self)
1371 1371
1372 1372 return result
1373 1373
1374 1374 def __eq__(self, other):
1375 1375 """Tests equality of service name"""
1376 1376 if isinstance(other, ServiceInfo):
1377 1377 return other.name == self.name
1378 1378 return 0
1379 1379
1380 1380 def __ne__(self, other):
1381 1381 """Non-equality test"""
1382 1382 return not self.__eq__(other)
1383 1383
1384 1384 def __repr__(self):
1385 1385 """String representation"""
1386 1386 result = b"service[%s,%s:%s," % (
1387 1387 self.name,
1388 1388 socket.inet_ntoa(self.getAddress()),
1389 1389 self.port,
1390 1390 )
1391 1391 if self.text is None:
1392 1392 result += b"None"
1393 1393 else:
1394 1394 if len(self.text) < 20:
1395 1395 result += self.text
1396 1396 else:
1397 1397 result += self.text[:17] + b"..."
1398 1398 result += b"]"
1399 1399 return result
1400 1400
1401 1401
1402 1402 class Zeroconf:
1403 1403 """Implementation of Zeroconf Multicast DNS Service Discovery
1404 1404
1405 1405 Supports registration, unregistration, queries and browsing.
1406 1406 """
1407 1407
1408 1408 def __init__(self, bindaddress=None):
1409 1409 """Creates an instance of the Zeroconf class, establishing
1410 1410 multicast communications, listening and reaping threads."""
1411 globals()[b'_GLOBAL_DONE'] = 0
1411 globals()['_GLOBAL_DONE'] = 0
1412 1412 if bindaddress is None:
1413 1413 self.intf = socket.gethostbyname(socket.gethostname())
1414 1414 else:
1415 1415 self.intf = bindaddress
1416 1416 self.group = (b'', _MDNS_PORT)
1417 1417 self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1418 1418 try:
1419 1419 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1420 1420 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
1421 1421 except Exception:
1422 1422 # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
1423 1423 # multicast UDP sockets (p 731, "TCP/IP Illustrated,
1424 1424 # Volume 2"), but some BSD-derived systems require
1425 1425 # SO_REUSEPORT to be specified explicitly. Also, not all
1426 1426 # versions of Python have SO_REUSEPORT available. So
1427 1427 # if you're on a BSD-based system, and haven't upgraded
1428 1428 # to Python 2.3 yet, you may find this library doesn't
1429 1429 # work as expected.
1430 1430 #
1431 1431 pass
1432 1432 self.socket.setsockopt(_SOL_IP, socket.IP_MULTICAST_TTL, b"\xff")
1433 1433 self.socket.setsockopt(_SOL_IP, socket.IP_MULTICAST_LOOP, b"\x01")
1434 1434 try:
1435 1435 self.socket.bind(self.group)
1436 1436 except Exception:
1437 1437 # Some versions of linux raise an exception even though
1438 1438 # SO_REUSEADDR and SO_REUSEPORT have been set, so ignore it
1439 1439 pass
1440 1440 self.socket.setsockopt(
1441 1441 _SOL_IP,
1442 1442 socket.IP_ADD_MEMBERSHIP,
1443 1443 socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'),
1444 1444 )
1445 1445
1446 1446 self.listeners = []
1447 1447 self.browsers = []
1448 1448 self.services = {}
1449 1449 self.servicetypes = {}
1450 1450
1451 1451 self.cache = DNSCache()
1452 1452
1453 1453 self.condition = threading.Condition()
1454 1454
1455 1455 self.engine = Engine(self)
1456 1456 self.listener = Listener(self)
1457 1457 self.reaper = Reaper(self)
1458 1458
1459 1459 def isLoopback(self):
1460 1460 return self.intf.startswith(b"127.0.0.1")
1461 1461
1462 1462 def isLinklocal(self):
1463 1463 return self.intf.startswith(b"169.254.")
1464 1464
1465 1465 def wait(self, timeout):
1466 1466 """Calling thread waits for a given number of milliseconds or
1467 1467 until notified."""
1468 1468 self.condition.acquire()
1469 1469 self.condition.wait(timeout / 1000)
1470 1470 self.condition.release()
1471 1471
1472 1472 def notifyAll(self):
1473 1473 """Notifies all waiting threads"""
1474 1474 self.condition.acquire()
1475 1475 self.condition.notify_all()
1476 1476 self.condition.release()
1477 1477
1478 1478 def getServiceInfo(self, type, name, timeout=3000):
1479 1479 """Returns network's service information for a particular
1480 1480 name and type, or None if no service matches by the timeout,
1481 1481 which defaults to 3 seconds."""
1482 1482 info = ServiceInfo(type, name)
1483 1483 if info.request(self, timeout):
1484 1484 return info
1485 1485 return None
1486 1486
1487 1487 def addServiceListener(self, type, listener):
1488 1488 """Adds a listener for a particular service type. This object
1489 1489 will then have its updateRecord method called when information
1490 1490 arrives for that type."""
1491 1491 self.removeServiceListener(listener)
1492 1492 self.browsers.append(ServiceBrowser(self, type, listener))
1493 1493
1494 1494 def removeServiceListener(self, listener):
1495 1495 """Removes a listener from the set that is currently listening."""
1496 1496 for browser in self.browsers:
1497 1497 if browser.listener == listener:
1498 1498 browser.cancel()
1499 1499 del browser
1500 1500
1501 1501 def registerService(self, info, ttl=_DNS_TTL):
1502 1502 """Registers service information to the network with a default TTL
1503 1503 of 60 seconds. Zeroconf will then respond to requests for
1504 1504 information for that service. The name of the service may be
1505 1505 changed if needed to make it unique on the network."""
1506 1506 self.checkService(info)
1507 1507 self.services[info.name.lower()] = info
1508 1508 if info.type in self.servicetypes:
1509 1509 self.servicetypes[info.type] += 1
1510 1510 else:
1511 1511 self.servicetypes[info.type] = 1
1512 1512 now = currentTimeMillis()
1513 1513 nexttime = now
1514 1514 i = 0
1515 1515 while i < 3:
1516 1516 if now < nexttime:
1517 1517 self.wait(nexttime - now)
1518 1518 now = currentTimeMillis()
1519 1519 continue
1520 1520 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1521 1521 out.addAnswerAtTime(
1522 1522 DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0
1523 1523 )
1524 1524 out.addAnswerAtTime(
1525 1525 DNSService(
1526 1526 info.name,
1527 1527 _TYPE_SRV,
1528 1528 _CLASS_IN,
1529 1529 ttl,
1530 1530 info.priority,
1531 1531 info.weight,
1532 1532 info.port,
1533 1533 info.server,
1534 1534 ),
1535 1535 0,
1536 1536 )
1537 1537 out.addAnswerAtTime(
1538 1538 DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0
1539 1539 )
1540 1540 if info.address:
1541 1541 out.addAnswerAtTime(
1542 1542 DNSAddress(
1543 1543 info.server, _TYPE_A, _CLASS_IN, ttl, info.address
1544 1544 ),
1545 1545 0,
1546 1546 )
1547 1547 self.send(out)
1548 1548 i += 1
1549 1549 nexttime += _REGISTER_TIME
1550 1550
1551 1551 def unregisterService(self, info):
1552 1552 """Unregister a service."""
1553 1553 try:
1554 1554 del self.services[info.name.lower()]
1555 1555 if self.servicetypes[info.type] > 1:
1556 1556 self.servicetypes[info.type] -= 1
1557 1557 else:
1558 1558 del self.servicetypes[info.type]
1559 1559 except KeyError:
1560 1560 pass
1561 1561 now = currentTimeMillis()
1562 1562 nexttime = now
1563 1563 i = 0
1564 1564 while i < 3:
1565 1565 if now < nexttime:
1566 1566 self.wait(nexttime - now)
1567 1567 now = currentTimeMillis()
1568 1568 continue
1569 1569 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1570 1570 out.addAnswerAtTime(
1571 1571 DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0
1572 1572 )
1573 1573 out.addAnswerAtTime(
1574 1574 DNSService(
1575 1575 info.name,
1576 1576 _TYPE_SRV,
1577 1577 _CLASS_IN,
1578 1578 0,
1579 1579 info.priority,
1580 1580 info.weight,
1581 1581 info.port,
1582 1582 info.name,
1583 1583 ),
1584 1584 0,
1585 1585 )
1586 1586 out.addAnswerAtTime(
1587 1587 DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0
1588 1588 )
1589 1589 if info.address:
1590 1590 out.addAnswerAtTime(
1591 1591 DNSAddress(
1592 1592 info.server, _TYPE_A, _CLASS_IN, 0, info.address
1593 1593 ),
1594 1594 0,
1595 1595 )
1596 1596 self.send(out)
1597 1597 i += 1
1598 1598 nexttime += _UNREGISTER_TIME
1599 1599
1600 1600 def unregisterAllServices(self):
1601 1601 """Unregister all registered services."""
1602 1602 if len(self.services) > 0:
1603 1603 now = currentTimeMillis()
1604 1604 nexttime = now
1605 1605 i = 0
1606 1606 while i < 3:
1607 1607 if now < nexttime:
1608 1608 self.wait(nexttime - now)
1609 1609 now = currentTimeMillis()
1610 1610 continue
1611 1611 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1612 1612 for info in self.services.values():
1613 1613 out.addAnswerAtTime(
1614 1614 DNSPointer(
1615 1615 info.type, _TYPE_PTR, _CLASS_IN, 0, info.name
1616 1616 ),
1617 1617 0,
1618 1618 )
1619 1619 out.addAnswerAtTime(
1620 1620 DNSService(
1621 1621 info.name,
1622 1622 _TYPE_SRV,
1623 1623 _CLASS_IN,
1624 1624 0,
1625 1625 info.priority,
1626 1626 info.weight,
1627 1627 info.port,
1628 1628 info.server,
1629 1629 ),
1630 1630 0,
1631 1631 )
1632 1632 out.addAnswerAtTime(
1633 1633 DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text),
1634 1634 0,
1635 1635 )
1636 1636 if info.address:
1637 1637 out.addAnswerAtTime(
1638 1638 DNSAddress(
1639 1639 info.server, _TYPE_A, _CLASS_IN, 0, info.address
1640 1640 ),
1641 1641 0,
1642 1642 )
1643 1643 self.send(out)
1644 1644 i += 1
1645 1645 nexttime += _UNREGISTER_TIME
1646 1646
1647 1647 def checkService(self, info):
1648 1648 """Checks the network for a unique service name, modifying the
1649 1649 ServiceInfo passed in if it is not unique."""
1650 1650 now = currentTimeMillis()
1651 1651 nexttime = now
1652 1652 i = 0
1653 1653 while i < 3:
1654 1654 for record in self.cache.entriesWithName(info.type):
1655 1655 if (
1656 1656 record.type == _TYPE_PTR
1657 1657 and not record.isExpired(now)
1658 1658 and record.alias == info.name
1659 1659 ):
1660 1660 if info.name.find(b'.') < 0:
1661 1661 info.name = b"%s.[%s:%d].%s" % (
1662 1662 info.name,
1663 1663 info.address,
1664 1664 info.port,
1665 1665 info.type,
1666 1666 )
1667 1667 self.checkService(info)
1668 1668 return
1669 1669 raise NonUniqueNameException
1670 1670 if now < nexttime:
1671 1671 self.wait(nexttime - now)
1672 1672 now = currentTimeMillis()
1673 1673 continue
1674 1674 out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
1675 1675 self.debug = out
1676 1676 out.addQuestion(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
1677 1677 out.addAuthoritativeAnswer(
1678 1678 DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, info.name)
1679 1679 )
1680 1680 self.send(out)
1681 1681 i += 1
1682 1682 nexttime += _CHECK_TIME
1683 1683
1684 1684 def addListener(self, listener, question):
1685 1685 """Adds a listener for a given question. The listener will have
1686 1686 its updateRecord method called when information is available to
1687 1687 answer the question."""
1688 1688 now = currentTimeMillis()
1689 1689 self.listeners.append(listener)
1690 1690 if question is not None:
1691 1691 for record in self.cache.entriesWithName(question.name):
1692 1692 if question.answeredBy(record) and not record.isExpired(now):
1693 1693 listener.updateRecord(self, now, record)
1694 1694 self.notifyAll()
1695 1695
1696 1696 def removeListener(self, listener):
1697 1697 """Removes a listener."""
1698 1698 try:
1699 1699 self.listeners.remove(listener)
1700 1700 self.notifyAll()
1701 1701 except Exception:
1702 1702 pass
1703 1703
1704 1704 def updateRecord(self, now, rec):
1705 1705 """Used to notify listeners of new information that has updated
1706 1706 a record."""
1707 1707 for listener in self.listeners:
1708 1708 listener.updateRecord(self, now, rec)
1709 1709 self.notifyAll()
1710 1710
1711 1711 def handleResponse(self, msg):
1712 1712 """Deal with incoming response packets. All answers
1713 1713 are held in the cache, and listeners are notified."""
1714 1714 now = currentTimeMillis()
1715 1715 for record in msg.answers:
1716 1716 expired = record.isExpired(now)
1717 1717 if record in self.cache.entries():
1718 1718 if expired:
1719 1719 self.cache.remove(record)
1720 1720 else:
1721 1721 entry = self.cache.get(record)
1722 1722 if entry is not None:
1723 1723 entry.resetTTL(record)
1724 1724 record = entry
1725 1725 else:
1726 1726 self.cache.add(record)
1727 1727
1728 1728 self.updateRecord(now, record)
1729 1729
1730 1730 def handleQuery(self, msg, addr, port):
1731 1731 """Deal with incoming query packets. Provides a response if
1732 1732 possible."""
1733 1733 out = None
1734 1734
1735 1735 # Support unicast client responses
1736 1736 #
1737 1737 if port != _MDNS_PORT:
1738 1738 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, 0)
1739 1739 for question in msg.questions:
1740 1740 out.addQuestion(question)
1741 1741
1742 1742 for question in msg.questions:
1743 1743 if question.type == _TYPE_PTR:
1744 1744 if question.name == b"_services._dns-sd._udp.local.":
1745 1745 for stype in self.servicetypes.keys():
1746 1746 if out is None:
1747 1747 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1748 1748 out.addAnswer(
1749 1749 msg,
1750 1750 DNSPointer(
1751 1751 b"_services._dns-sd._udp.local.",
1752 1752 _TYPE_PTR,
1753 1753 _CLASS_IN,
1754 1754 _DNS_TTL,
1755 1755 stype,
1756 1756 ),
1757 1757 )
1758 1758 for service in self.services.values():
1759 1759 if question.name == service.type:
1760 1760 if out is None:
1761 1761 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1762 1762 out.addAnswer(
1763 1763 msg,
1764 1764 DNSPointer(
1765 1765 service.type,
1766 1766 _TYPE_PTR,
1767 1767 _CLASS_IN,
1768 1768 _DNS_TTL,
1769 1769 service.name,
1770 1770 ),
1771 1771 )
1772 1772 else:
1773 1773 try:
1774 1774 if out is None:
1775 1775 out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
1776 1776
1777 1777 # Answer A record queries for any service addresses we know
1778 1778 if question.type == _TYPE_A or question.type == _TYPE_ANY:
1779 1779 for service in self.services.values():
1780 1780 if service.server == question.name.lower():
1781 1781 out.addAnswer(
1782 1782 msg,
1783 1783 DNSAddress(
1784 1784 question.name,
1785 1785 _TYPE_A,
1786 1786 _CLASS_IN | _CLASS_UNIQUE,
1787 1787 _DNS_TTL,
1788 1788 service.address,
1789 1789 ),
1790 1790 )
1791 1791
1792 1792 service = self.services.get(question.name.lower(), None)
1793 1793 if not service:
1794 1794 continue
1795 1795
1796 1796 if question.type == _TYPE_SRV or question.type == _TYPE_ANY:
1797 1797 out.addAnswer(
1798 1798 msg,
1799 1799 DNSService(
1800 1800 question.name,
1801 1801 _TYPE_SRV,
1802 1802 _CLASS_IN | _CLASS_UNIQUE,
1803 1803 _DNS_TTL,
1804 1804 service.priority,
1805 1805 service.weight,
1806 1806 service.port,
1807 1807 service.server,
1808 1808 ),
1809 1809 )
1810 1810 if question.type == _TYPE_TXT or question.type == _TYPE_ANY:
1811 1811 out.addAnswer(
1812 1812 msg,
1813 1813 DNSText(
1814 1814 question.name,
1815 1815 _TYPE_TXT,
1816 1816 _CLASS_IN | _CLASS_UNIQUE,
1817 1817 _DNS_TTL,
1818 1818 service.text,
1819 1819 ),
1820 1820 )
1821 1821 if question.type == _TYPE_SRV:
1822 1822 out.addAdditionalAnswer(
1823 1823 DNSAddress(
1824 1824 service.server,
1825 1825 _TYPE_A,
1826 1826 _CLASS_IN | _CLASS_UNIQUE,
1827 1827 _DNS_TTL,
1828 1828 service.address,
1829 1829 )
1830 1830 )
1831 1831 except Exception:
1832 1832 traceback.print_exc()
1833 1833
1834 1834 if out is not None and out.answers:
1835 1835 out.id = msg.id
1836 1836 self.send(out, addr, port)
1837 1837
1838 1838 def send(self, out, addr=_MDNS_ADDR, port=_MDNS_PORT):
1839 1839 """Sends an outgoing packet."""
1840 1840 # This is a quick test to see if we can parse the packets we generate
1841 1841 # temp = DNSIncoming(out.packet())
1842 1842 try:
1843 1843 self.socket.sendto(out.packet(), 0, (addr, port))
1844 1844 except Exception:
1845 1845 # Ignore this, it may be a temporary loss of network connection
1846 1846 pass
1847 1847
1848 1848 def close(self):
1849 1849 """Ends the background threads, and prevent this instance from
1850 1850 servicing further queries."""
1851 if globals()[b'_GLOBAL_DONE'] == 0:
1852 globals()[b'_GLOBAL_DONE'] = 1
1851 if globals()['_GLOBAL_DONE'] == 0:
1852 globals()['_GLOBAL_DONE'] = 1
1853 1853 self.notifyAll()
1854 1854 self.engine.notify()
1855 1855 self.unregisterAllServices()
1856 1856 self.socket.setsockopt(
1857 1857 _SOL_IP,
1858 1858 socket.IP_DROP_MEMBERSHIP,
1859 1859 socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'),
1860 1860 )
1861 1861 self.socket.close()
1862 1862
1863 1863
1864 1864 # Test a few module features, including service registration, service
1865 1865 # query (for Zoe), and service unregistration.
1866 1866
1867 1867 if __name__ == '__main__':
1868 1868 print(b"Multicast DNS Service Discovery for Python, version", __version__)
1869 1869 r = Zeroconf()
1870 1870 print(b"1. Testing registration of a service...")
1871 1871 desc = {b'version': b'0.10', b'a': b'test value', b'b': b'another value'}
1872 1872 info = ServiceInfo(
1873 1873 b"_http._tcp.local.",
1874 1874 b"My Service Name._http._tcp.local.",
1875 1875 socket.inet_aton("127.0.0.1"),
1876 1876 1234,
1877 1877 0,
1878 1878 0,
1879 1879 desc,
1880 1880 )
1881 1881 print(b" Registering service...")
1882 1882 r.registerService(info)
1883 1883 print(b" Registration done.")
1884 1884 print(b"2. Testing query of service information...")
1885 1885 print(
1886 1886 b" Getting ZOE service:",
1887 1887 str(r.getServiceInfo(b"_http._tcp.local.", b"ZOE._http._tcp.local.")),
1888 1888 )
1889 1889 print(b" Query done.")
1890 1890 print(b"3. Testing query of own service...")
1891 1891 print(
1892 1892 b" Getting self:",
1893 1893 str(
1894 1894 r.getServiceInfo(
1895 1895 b"_http._tcp.local.", b"My Service Name._http._tcp.local."
1896 1896 )
1897 1897 ),
1898 1898 )
1899 1899 print(b" Query done.")
1900 1900 print(b"4. Testing unregister of service information...")
1901 1901 r.unregisterService(info)
1902 1902 print(b" Unregister done.")
1903 1903 r.close()
General Comments 0
You need to be logged in to leave comments. Login now