Rewrite of KRPC / DHT classes

KRPC: Using own socket wrapper to provide needed (non-)/blocking behaviour of send/recv
DHT: Improved handling of malformed messages
Improved granularity of logging in DHT / KRPC and utils
Small performance fixes
Replaced lambdas by functions to facilitate code instrumentation
Decreased the shutdown time for KRPC nodes
bencode python 3
crc32 python 3
dht python 3
master
Fred Stober 2015-11-09 16:57:07 +01:00
parent 553fe75e72
commit 06dd22bec7
6 changed files with 690 additions and 457 deletions

View File

@ -41,19 +41,20 @@ an async result holder with the unprocessed data from the remote host:
In addition, some additional helper functions are made available - these
functions take care of updating the routing table and are blocking calls with
a user specified timeout:
- dht_ping(connection, timeout = 1)
- dht_ping(connection, timeout = 5)
Returns the complete result dictionary of the call.
- dht_find_node(search_id, timeout = 120)
- dht_find_node(search_id, timeout = 5, retries = 2)
Searches iteratively for nodes with the given id
and yields the connection tuple if found.
- dht_get_peers(info_hash, timeout = 120)
- dht_get_peers(info_hash, timeout = 5, retries = 2)
Searches iteratively for nodes with the given info_hash
and yields the connection tuple if found.
- dht_announce_peer(info_hash)
- dht_announce_peer(info_hash, implied_port = 1)
Registers the availabilty of the info_hash on this node
to all peers that supplied a token while searching for it.
The final two functions are used to start and shutdown the local DHT Peer:
The final three functions are used to start and shutdown the local DHT Peer
and allow access to the discovered external connection infos:
- __init__(listen_connection, bootstrap_connection = ('router.bittorrent.com', 6881),
setup = {'report_t': 10, 'check_t': 30, 'check_N': 10, 'discover_t': 180})
The constructor needs to know what address and port to listen on and which node to use
@ -61,3 +62,5 @@ The final two functions are used to start and shutdown the local DHT Peer:
threads can be configured as well.
- shutdown()
Start shutdown of the local DHT peer and all associated maintainance threads.
- get_external_connection()
Return the discovered external connection infos

View File

@ -1,129 +1,115 @@
# The contents of this file are subject to the BitTorrent Open Source License
# Version 1.1 (the License). You may not copy or use this file, in either
# source code or executable form, except in compliance with the License. You
# may obtain a copy of the License at http://www.bittorrent.com/license/.
#
# Software distributed under the License is distributed on an AS IS basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
# for the specific language governing rights and limitations under the
# License.
"""
The MIT License
# Written by Petru Paler
Copyright (c) 2015 Fred Stober
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
class BTFailure(Exception):
pass
pass
def decode_int(x, f):
f += 1
newf = x.index('e', f)
n = int(x[f:newf])
if x[f] == '-':
if x[f + 1] == '0':
raise ValueError
elif x[f] == '0' and newf != f+1:
raise ValueError
return (n, newf+1)
# Encoding functions ##############################################
def decode_string(x, f):
colon = x.index(':', f)
n = int(x[f:colon])
if x[f] == '0' and colon != f+1:
raise ValueError
colon += 1
return (x[colon:colon+n], colon+n)
import sys
if sys.version_info.major >= 3:
str_to_bytes = lambda x: x.encode('ascii')
else:
str_to_bytes = lambda x: x
def decode_list(x, f):
r, f = [], f+1
while x[f] != 'e':
v, f = decode_func[x[f]](x, f)
r.append(v)
return (r, f + 1)
def decode_dict(x, f):
r, f = {}, f+1
while x[f] != 'e':
k, f = decode_string(x, f)
r[k], f = decode_func[x[f]](x, f)
return (r, f + 1)
decode_func = {}
decode_func['l'] = decode_list
decode_func['d'] = decode_dict
decode_func['i'] = decode_int
decode_func['0'] = decode_string
decode_func['1'] = decode_string
decode_func['2'] = decode_string
decode_func['3'] = decode_string
decode_func['4'] = decode_string
decode_func['5'] = decode_string
decode_func['6'] = decode_string
decode_func['7'] = decode_string
decode_func['8'] = decode_string
decode_func['9'] = decode_string
def bdecode(x):
try:
r, l = decode_func[x[0]](x, 0)
except (IndexError, KeyError, ValueError):
raise BTFailure("not a valid bencoded string")
if l != len(x):
raise BTFailure("invalid bencoded value (data after valid prefix)")
return r
from types import StringType, IntType, LongType, DictType, ListType, TupleType
class Bencached(object):
__slots__ = ['bencoded']
def __init__(self, s):
self.bencoded = s
def encode_bencached(x,r):
r.append(x.bencoded)
def encode_int(x, r):
r.extend(('i', str(x), 'e'))
def encode_bool(x, r):
if x:
encode_int(1, r)
else:
encode_int(0, r)
def encode_string(x, r):
r.extend((str(len(x)), ':', x))
def encode_list(x, r):
r.append('l')
for i in x:
encode_func[type(i)](i, r)
r.append('e')
def encode_dict(x,r):
r.append('d')
ilist = x.items()
ilist.sort()
for k, v in ilist:
r.extend((str(len(k)), ':', k))
encode_func[type(v)](v, r)
r.append('e')
encode_func = {}
encode_func[Bencached] = encode_bencached
encode_func[IntType] = encode_int
encode_func[LongType] = encode_int
encode_func[StringType] = encode_string
encode_func[ListType] = encode_list
encode_func[TupleType] = encode_list
encode_func[DictType] = encode_dict
try:
from types import BooleanType
encode_func[BooleanType] = encode_bool
except ImportError:
pass
def bencode_proc(result, x):
t = type(x)
if t == str:
result.extend((str_to_bytes(str(len(x))), b':', str_to_bytes(x)))
elif t == bytes:
result.extend((str_to_bytes(str(len(x))), b':', x))
elif t == int:
result.extend((b'i', str_to_bytes(str(x)), b'e'))
elif t == dict:
result.append(b'd')
for k, v in sorted(x.items()):
bencode_proc(result, k)
bencode_proc(result, v)
result.append(b'e')
elif t == list:
result.append(b'l')
for item in x:
bencode_proc(result, item)
result.append(b'e')
def bencode(x):
r = []
encode_func[type(x)](x, r)
return ''.join(r)
result = []
bencode_proc(result, x)
return b''.join(result)
# Decoding functions ##############################################
bdecode_marker_int = ord('i')
bdecode_marker_str_min = ord('0')
bdecode_marker_str_max = ord('9')
bdecode_marker_list = ord('l')
bdecode_marker_dict = ord('d')
bdecode_marker_end = ord('e')
def bdecode_proc(msg, pos):
t = msg[pos]
if t == bdecode_marker_int:
pos += 1
pos_end = msg.index(b'e', pos)
return (int(msg[pos:pos_end]), pos_end + 1)
elif t >= bdecode_marker_str_min and t <= bdecode_marker_str_max:
sep = msg.index(b':', pos)
n = int(msg[pos:sep])
sep += 1
return (bytes(msg[sep:sep + n]), sep + n)
elif t == bdecode_marker_dict:
result = {}
pos += 1
while msg[pos] != bdecode_marker_end:
k, pos = bdecode_proc(msg, pos)
result[k], pos = bdecode_proc(msg, pos)
return (result, pos + 1)
elif t == bdecode_marker_list:
result = []
pos += 1
while msg[pos] != bdecode_marker_end:
v, pos = bdecode_proc(msg, pos)
result.append(v)
return (result, pos + 1)
def bdecode_extra(msg):
try:
result, pos = bdecode_proc(bytearray(msg), 0)
except (IndexError, KeyError, ValueError):
raise BTFailure("invalid bencoded data! %r" % msg)
return (result, pos)
def bdecode(msg):
try:
result, pos = bdecode_extra(msg)
except (IndexError, KeyError, ValueError):
raise BTFailure("invalid bencoded data: %r" % msg)
if pos != len(msg):
raise BTFailure("invalid bencoded value (data after valid prefix)")
return result
if __name__ == '__main__':
test = {b'k1': 145, b'k2': {b'sk1': list(range(10)), b'sk2': b'0'*60}}
for x in range(100000):
assert(bdecode(bencode(test)) == test)

142
crc32c.py
View File

@ -24,75 +24,81 @@ THE SOFTWARE.
# generated using pycrc (www.tty1.net/pycrc)
crc32c_table = (
0x00000000L, 0xf26b8303L, 0xe13b70f7L, 0x1350f3f4L,
0xc79a971fL, 0x35f1141cL, 0x26a1e7e8L, 0xd4ca64ebL,
0x8ad958cfL, 0x78b2dbccL, 0x6be22838L, 0x9989ab3bL,
0x4d43cfd0L, 0xbf284cd3L, 0xac78bf27L, 0x5e133c24L,
0x105ec76fL, 0xe235446cL, 0xf165b798L, 0x030e349bL,
0xd7c45070L, 0x25afd373L, 0x36ff2087L, 0xc494a384L,
0x9a879fa0L, 0x68ec1ca3L, 0x7bbcef57L, 0x89d76c54L,
0x5d1d08bfL, 0xaf768bbcL, 0xbc267848L, 0x4e4dfb4bL,
0x20bd8edeL, 0xd2d60dddL, 0xc186fe29L, 0x33ed7d2aL,
0xe72719c1L, 0x154c9ac2L, 0x061c6936L, 0xf477ea35L,
0xaa64d611L, 0x580f5512L, 0x4b5fa6e6L, 0xb93425e5L,
0x6dfe410eL, 0x9f95c20dL, 0x8cc531f9L, 0x7eaeb2faL,
0x30e349b1L, 0xc288cab2L, 0xd1d83946L, 0x23b3ba45L,
0xf779deaeL, 0x05125dadL, 0x1642ae59L, 0xe4292d5aL,
0xba3a117eL, 0x4851927dL, 0x5b016189L, 0xa96ae28aL,
0x7da08661L, 0x8fcb0562L, 0x9c9bf696L, 0x6ef07595L,
0x417b1dbcL, 0xb3109ebfL, 0xa0406d4bL, 0x522bee48L,
0x86e18aa3L, 0x748a09a0L, 0x67dafa54L, 0x95b17957L,
0xcba24573L, 0x39c9c670L, 0x2a993584L, 0xd8f2b687L,
0x0c38d26cL, 0xfe53516fL, 0xed03a29bL, 0x1f682198L,
0x5125dad3L, 0xa34e59d0L, 0xb01eaa24L, 0x42752927L,
0x96bf4dccL, 0x64d4cecfL, 0x77843d3bL, 0x85efbe38L,
0xdbfc821cL, 0x2997011fL, 0x3ac7f2ebL, 0xc8ac71e8L,
0x1c661503L, 0xee0d9600L, 0xfd5d65f4L, 0x0f36e6f7L,
0x61c69362L, 0x93ad1061L, 0x80fde395L, 0x72966096L,
0xa65c047dL, 0x5437877eL, 0x4767748aL, 0xb50cf789L,
0xeb1fcbadL, 0x197448aeL, 0x0a24bb5aL, 0xf84f3859L,
0x2c855cb2L, 0xdeeedfb1L, 0xcdbe2c45L, 0x3fd5af46L,
0x7198540dL, 0x83f3d70eL, 0x90a324faL, 0x62c8a7f9L,
0xb602c312L, 0x44694011L, 0x5739b3e5L, 0xa55230e6L,
0xfb410cc2L, 0x092a8fc1L, 0x1a7a7c35L, 0xe811ff36L,
0x3cdb9bddL, 0xceb018deL, 0xdde0eb2aL, 0x2f8b6829L,
0x82f63b78L, 0x709db87bL, 0x63cd4b8fL, 0x91a6c88cL,
0x456cac67L, 0xb7072f64L, 0xa457dc90L, 0x563c5f93L,
0x082f63b7L, 0xfa44e0b4L, 0xe9141340L, 0x1b7f9043L,
0xcfb5f4a8L, 0x3dde77abL, 0x2e8e845fL, 0xdce5075cL,
0x92a8fc17L, 0x60c37f14L, 0x73938ce0L, 0x81f80fe3L,
0x55326b08L, 0xa759e80bL, 0xb4091bffL, 0x466298fcL,
0x1871a4d8L, 0xea1a27dbL, 0xf94ad42fL, 0x0b21572cL,
0xdfeb33c7L, 0x2d80b0c4L, 0x3ed04330L, 0xccbbc033L,
0xa24bb5a6L, 0x502036a5L, 0x4370c551L, 0xb11b4652L,
0x65d122b9L, 0x97baa1baL, 0x84ea524eL, 0x7681d14dL,
0x2892ed69L, 0xdaf96e6aL, 0xc9a99d9eL, 0x3bc21e9dL,
0xef087a76L, 0x1d63f975L, 0x0e330a81L, 0xfc588982L,
0xb21572c9L, 0x407ef1caL, 0x532e023eL, 0xa145813dL,
0x758fe5d6L, 0x87e466d5L, 0x94b49521L, 0x66df1622L,
0x38cc2a06L, 0xcaa7a905L, 0xd9f75af1L, 0x2b9cd9f2L,
0xff56bd19L, 0x0d3d3e1aL, 0x1e6dcdeeL, 0xec064eedL,
0xc38d26c4L, 0x31e6a5c7L, 0x22b65633L, 0xd0ddd530L,
0x0417b1dbL, 0xf67c32d8L, 0xe52cc12cL, 0x1747422fL,
0x49547e0bL, 0xbb3ffd08L, 0xa86f0efcL, 0x5a048dffL,
0x8ecee914L, 0x7ca56a17L, 0x6ff599e3L, 0x9d9e1ae0L,
0xd3d3e1abL, 0x21b862a8L, 0x32e8915cL, 0xc083125fL,
0x144976b4L, 0xe622f5b7L, 0xf5720643L, 0x07198540L,
0x590ab964L, 0xab613a67L, 0xb831c993L, 0x4a5a4a90L,
0x9e902e7bL, 0x6cfbad78L, 0x7fab5e8cL, 0x8dc0dd8fL,
0xe330a81aL, 0x115b2b19L, 0x020bd8edL, 0xf0605beeL,
0x24aa3f05L, 0xd6c1bc06L, 0xc5914ff2L, 0x37faccf1L,
0x69e9f0d5L, 0x9b8273d6L, 0x88d28022L, 0x7ab90321L,
0xae7367caL, 0x5c18e4c9L, 0x4f48173dL, 0xbd23943eL,
0xf36e6f75L, 0x0105ec76L, 0x12551f82L, 0xe03e9c81L,
0x34f4f86aL, 0xc69f7b69L, 0xd5cf889dL, 0x27a40b9eL,
0x79b737baL, 0x8bdcb4b9L, 0x988c474dL, 0x6ae7c44eL,
0xbe2da0a5L, 0x4c4623a6L, 0x5f16d052L, 0xad7d5351L,
0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4,
0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb,
0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24,
0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b,
0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54,
0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b,
0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35,
0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5,
0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45,
0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a,
0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595,
0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48,
0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687,
0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198,
0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38,
0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8,
0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096,
0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789,
0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46,
0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9,
0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36,
0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829,
0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93,
0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043,
0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3,
0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc,
0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033,
0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652,
0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d,
0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982,
0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622,
0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2,
0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530,
0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f,
0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0,
0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f,
0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90,
0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f,
0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1,
0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321,
0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81,
0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,
0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
)
def crc32c(data):
""" return CRC32C checksum """
crc = 0xffffffffL
for byte in map(ord, data):
crc = (crc32c_table[(crc ^ byte) & 0xff] ^ (crc >> 8)) & 0xffffffffL
return (crc & 0xffffffffL) ^ 0xffffffffL
crc = 0xffffffff
for byte in data:
crc = (crc32c_table[(crc ^ byte) & 0xff] ^ (crc >> 8)) & 0xffffffff
return (crc & 0xffffffff) ^ 0xffffffff
if __name__ == '__main__':
import logging
logging.basicConfig()
log = logging.getLogger()
log.critical(crc32c(bytearray(b'1')))

382
dht.py
View File

@ -22,27 +22,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import os, time, socket, hashlib, hmac, threading, logging
import os, time, socket, hashlib, hmac, threading, logging, random, inspect
from bencode import bencode, bdecode
from utils import encode_int, encode_ip, encode_connection, encode_nodes, AsyncTimeout
from utils import decode_int, decode_ip, decode_connection, decode_nodes, start_thread
from utils import encode_uint32, encode_ip, encode_connection, encode_nodes, AsyncTimeout
from utils import decode_uint32, decode_ip, decode_connection, decode_nodes, start_thread, ThreadManager
from krpc import KRPCPeer, KRPCError
# BEP #0042 - prefix is based on ip and last byte of the node id - 21 most significant bits must match
def bep42_prefix(ip, rand_char, rand_rest = '\x00'): # rand_rest determines the last (random) 3 bits
# * ip = ip address in string format eg. "127.0.0.1"
def bep42_prefix(ip, crc32_salt, first_node_bits): # first_node_bits determines the last 3 bits
from crc32c import crc32c
ip = decode_int(encode_ip(ip))
value = crc32c(encode_int((ip & 0x030f3fff) | ((ord(rand_char) & 0x7) << 29)))
return (value & 0xfffff800) | ((ord(rand_rest) << 8) & 0x00000700)
ip_asint = decode_uint32(encode_ip(ip))
value = crc32c(bytearray(encode_uint32((ip_asint & 0x030f3fff) | ((crc32_salt & 0x7) << 29))))
return (value & 0xfffff800) | ((first_node_bits << 8) & 0x00000700)
def valid_id(node_id, connection):
vprefix = bep42_prefix(connection[0], node_id[-1])
return (((vprefix ^ decode_int(node_id[:4])) & 0xfffff800) == 0)
def strxor(a, b):
assert(len(a) == len(b))
return int(a.encode('hex'), 16) ^ int(b.encode('hex'), 16)
node_id = bytearray(node_id)
vprefix = bep42_prefix(connection[0], node_id[-1], 0)
return (((vprefix ^ decode_uint32(node_id[:4])) & 0xfffff800) == 0)
def decode_id(node_id):
try: # python 3
return int.from_bytes(node_id, byteorder='big')
except:
return int(node_id.encode('hex'), 16)
class DHT_Node(object):
def __init__(self, connection, id, version = None):
@ -57,43 +60,89 @@ class DHT_Node(object):
def set_id(self, id):
self.id = id
self.id_cmp = int(id.encode('hex'), 16)
self.id_cmp = decode_id(id)
def __repr__(self):
return '%s %15s %5d %20s %5s %.2f' % (self.id.encode('hex'), self.connection[0], self.connection[1],
return 'id:%s con:%15s:%-5d v:%20s c:%5s last:%.2f' % (hex(self.id_cmp), self.connection[0], self.connection[1],
repr(self.version), valid_id(self.id, self.connection), time.time() - self.last_ping)
# Trivial node list implementation
class DHT_Router(object):
def __init__(self, name):
def __init__(self, name, user_setup = {}):
setup = {'report_t': 10, 'limit_t': 30, 'limit_N': 2000, 'redeem_t': 300, 'redeem_frac': 0.05}
setup.update(user_setup)
self._log = logging.getLogger(self.__class__.__name__ + '.%s' % name)
# This is our routing table.
# This is our (trivial) routing table.
self._nodes = {}
self._nodes_lock = threading.Lock()
self._nodes_lock = threading.RLock()
self._nodes_protected = set()
self._connections_bad = set()
# Start maintainance threads
self._threads = ThreadManager(self._log.getChild('maintainance'))
self.shutdown = self._threads.shutdown
# - Report status of routing table
def _show_status():
with self._nodes_lock:
self._log.info('Routing table contains %d ids with %d nodes (%d bad, %s protected)' %\
(len(self._nodes), sum(map(len, self._nodes.values())),
len(self._connections_bad), len(self._nodes_protected)))
if self._log.isEnabledFor(logging.DEBUG):
for node in self.get_nodes():
self._log.debug('\t%r' % node)
self._threads.start_continuous_thread(_show_status, thread_interval = setup['report_t'], thread_waitfirst = True)
# - Limit number of active nodes
def _limit(maxN):
self._log.debug('Starting limitation of nodes')
N = len(self.get_nodes())
if N > maxN:
for node in self.get_nodes(N - maxN,
expression = lambda n: n.connection not in self._connections_bad,
sorter = lambda x: random.random()):
self.remove_node(node, force = True)
self._threads.start_continuous_thread(_limit, thread_interval = setup['limit_t'], maxN = setup['limit_N'], thread_waitfirst = True)
# - Redeem random nodes from the blacklist
def _redeem_connections(fraction):
self._log.debug('Starting redemption of blacklisted nodes')
remove = int(fraction * len(self._connections_bad))
with self._nodes_lock:
while self._connections_bad and (remove > 0):
self._connections_bad.pop()
remove -= 1
self._threads.start_continuous_thread(_redeem_connections, thread_interval = setup['redeem_t'], fraction = setup['redeem_frac'], thread_waitfirst = True)
def protect_nodes(self, node_id_list):
self._log.info('protect %s' % repr(sorted(node_id_list)))
with self._nodes_lock:
self._nodes_protected.update(node_id_list)
def good_node(self, node):
with self._nodes_lock:
node.attempt = 0
def remove_node(self, node, force = False):
with self._nodes_lock:
node.attempt += 1
if node.id in self._nodes:
if force or ((node.id not in self._nodes_protected) and (node.attempt > 2)):
max_attempts = 2
if valid_id(node.id, node.connection):
max_attempts = 5
if force or ((node.id not in self._nodes_protected) and (node.attempt > max_attempts)):
if not force:
self._connections_bad.add(node.connection)
self._nodes[node.id] = filter(lambda n: n.connection != node.connection, self._nodes[node.id])
def is_not_removed_node(n):
return n.connection != node.connection
self._nodes[node.id] = list(filter(is_not_removed_node, self._nodes[node.id]))
if not self._nodes[node.id]:
self._nodes.pop(node.id)
def register_node(self, node_connection, node_id, node_version = None):
with self._nodes_lock:
if node_connection in self._connections_bad:
@ -112,7 +161,7 @@ class DHT_Router(object):
return node
# Return nodes matching a filter expression
def get_nodes(self, N = None, expression = lambda n: True, sorter = None):
def get_nodes(self, N = None, expression = lambda n: True, sorter = lambda n: n.id_cmp):
if len(self._nodes) == 0:
raise RuntimeError('No nodes in routing table!')
result = []
@ -124,151 +173,140 @@ class DHT_Router(object):
return result
return result[:N]
def redeem_connections(self, fraction = 0.05):
remove = int(fraction * len(self._connections_bad))
with self._nodes_lock:
while self._connections_bad and (remove > 0):
self._connections_bad.pop()
remove -= 1
def show_status(self):
with self._nodes_lock:
self._log.info('Routing table contains %d nodes (%d blacklisted, %s protected)' %\
(len(self._nodes), len(self._connections_bad), len(self._nodes_protected)))
if self._log.isEnabledFor(logging.DEBUG):
for node in self.get_nodes():
self._log.debug('\t%r' % node)
class DHT(object):
def __init__(self, listen_connection, bootstrap_connection = ('router.bittorrent.com', 6881), user_setup = {}):
def __init__(self, listen_connection, bootstrap_connection = ('router.bittorrent.com', 6881),
user_setup = {}, user_router = None):
""" Start DHT peer on given (host, port) and bootstrap connection to the DHT """
setup = {'report_t': 10, 'check_t': 30, 'check_N': 10, 'discover_t': 180, 'redeem_t': 300}
setup = {'discover_t': 180, 'check_t': 30, 'check_N': 10}
setup.update(user_setup)
self._log = logging.getLogger(self.__class__.__name__ + '.%s.%d' % listen_connection)
self._log.info('Starting DHT node with bootstrap connection %s:%d' % bootstrap_connection)
listen_connection = (socket.gethostbyname(listen_connection[0]), listen_connection[1])
# Generate key for token generation
self._token_key = os.urandom(20)
# Start KRPC server process and Routing table
self._krpc = KRPCPeer(listen_connection, self._handle_query, cleanup_interval = 1)
self._nodes = DHT_Router('%s.%d' % listen_connection)
self._krpc = KRPCPeer(listen_connection, self._handle_query)
if not user_router:
user_router = DHT_Router('%s.%d' % listen_connection, setup)
self._nodes = user_router
self._node = DHT_Node(listen_connection, os.urandom(20))
self._node_lock = threading.RLock()
# Start bootstrap process
try:
tmp = self.ping(bootstrap_connection, sender_id = self._node.id).get_result(timeout = 5)
tmp = self.ping(bootstrap_connection, sender_id = self._node.id).get_result(timeout = 1)
except Exception:
tmp = {'ip': encode_connection(listen_connection), 'r': {'id': self._node.id}}
self._node.connection = decode_connection(tmp['ip'])
self._bootstrap_node = self._nodes.register_node(bootstrap_connection, tmp['r']['id'])
raise
tmp = {b'ip': encode_connection(listen_connection), b'r': {b'id': self._node.id}}
self._node.connection = decode_connection(tmp[b'ip'])
self._bootstrap_node = self._nodes.register_node(bootstrap_connection, tmp[b'r'][b'id'])
# BEP #0042 Enable security extension
self._node.set_id(encode_int(bep42_prefix(self._node.connection[0], self._node.id[-1], self._node.id[0]))[:3] + self._node.id[3:])
local_id = bytearray(self._node.id)
bep42_value = encode_uint32(bep42_prefix(self._node.connection[0], local_id[-1], local_id[0]))
self._node.set_id(bep42_value[:3] + self._node.id[3:])
assert(valid_id(self._node.id, self._node.connection))
self._nodes.protect_nodes([self._node.id])
# Start maintainance threads
self._shutdown_event = threading.Event()
# Report status of routing table
self._thread_report = start_thread(self._maintainance_task, self._nodes.show_status,
interval = setup['report_t'])
self._threads = ThreadManager(self._log.getChild('maintainance'))
# Periodically ping nodes in the routing table
def _check_nodes(N):
check_nodes = list(self._nodes.get_nodes(N, expression = lambda n: (time.time() - n.last_ping > 15*60)))
def _check_nodes(N, last_ping = 15 * 60, timeout = 5):
def get_unpinged(n):
return time.time() - n.last_ping > last_ping
check_nodes = list(self._nodes.get_nodes(N, expression = get_unpinged))
if not check_nodes:
return
self._log.info('Starting cleanup of known nodes')
self._log.debug('Starting cleanup of known nodes')
node_result_list = []
for node in check_nodes:
node.last_ping = time.time()
node_result_list.append((node, node.id, self.ping(node.connection, self._node.id)))
t_end = time.time() + 5
t_end = time.time() + timeout
for (node, node_id, async_result) in node_result_list:
result = self._eval_dht_response(node, async_result, timeout = max(0, t_end - time.time()))
if node.id != result.get('id'):
if result and (node.id != result.get(b'id')): # remove nodes with changing identities
self._nodes.remove_node(node, force = True)
self._thread_check = start_thread(self._maintainance_task, _check_nodes,
interval = setup['check_t'], N = setup['check_N'])
# Redeem random nodes from the blacklist
def _redeem():
self._log.info('Starting redemption of blacklisted nodes')
self._nodes.redeem_connections()
self._thread_redeem = start_thread(self._maintainance_task, _redeem, interval = setup['redeem_t'])
self._threads.start_continuous_thread(_check_nodes, thread_interval = setup['check_t'], N = setup['check_N'])
# Try to discover a random node to populate routing table
def _discover_nodes():
self._log.info('Starting discovery of random node')
for idx, entry in enumerate(self.dht_find_node(os.urandom(20))):
self._log.debug('Starting discovery of random node')
for idx, entry in enumerate(self.dht_find_node(os.urandom(20), timeout = 1)):
if idx > 10:
break
self._thread_discovery = start_thread(self._maintainance_task, _discover_nodes,
interval = setup['discover_t'])
self._threads.start_continuous_thread(_discover_nodes, thread_interval = setup['discover_t'])
def get_external_ip(self):
def get_external_connection(self):
return self._node.connection
def shutdown(self):
""" This function allows to cleanly shutdown the DHT. """
self._log.info('shutting down DHT')
self._shutdown_event.set() # Trigger shutdown of maintainance threads
while True in map(threading.Thread.is_alive, [self._thread_report, self._thread_check,
self._thread_redeem, self._thread_discovery]):
time.sleep(0.1)
self._threads.shutdown() # Trigger shutdown of maintainance threads
self._krpc.shutdown() # Stop listening for incoming connections
# Maintainance task
def _maintainance_task(self, function, interval, **kwargs):
while interval > 0:
try:
function(**kwargs)
except Exception:
self._log.exception('Exception in DHT maintenance thread')
if self._shutdown_event.wait(interval):
return
self._nodes.shutdown()
self._threads.join() # Trigger shutdown of maintainance threads
# Handle remote queries
_reply_handler = {}
def _handle_query(self, send_krpc_reply, rec, source_connection):
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('handling query from %r: %r' % (source_connection, rec))
kwargs = rec['a']
if 'id' in kwargs:
self._nodes.register_node(source_connection, kwargs['id'], rec.get('v'))
query = rec['q']
if query in self._reply_handler:
send_dht_reply = lambda **kwargs: send_krpc_reply(kwargs,
try:
remote_args_dict = rec[b'a']
if b'id' in remote_args_dict:
self._nodes.register_node(source_connection, remote_args_dict[b'id'], rec.get(b'v'))
query = rec[b'q']
callback = self._reply_handler[query]
callback_kwargs = {}
for arg in inspect.getargspec(callback).args[2:]:
arg_bytes = arg.encode('ascii')
if arg_bytes in remote_args_dict:
callback_kwargs[arg] = remote_args_dict[arg_bytes]
def send_dht_reply(**kwargs):
# BEP #0042 - require ip field in answer
{'ip': encode_connection(source_connection)})
return send_krpc_reply(kwargs, {b'ip': encode_connection(source_connection)})
send_dht_reply.connection = source_connection
self._reply_handler[query](self, send_dht_reply, **kwargs)
else:
self._log.error('Unknown request in query %r' % rec)
callback(self, send_dht_reply, **callback_kwargs)
except Exception:
self._log.exception('Error while processing request %r' % rec)
# Evaluate async KRPC result and notify the routing table about failures
def _eval_dht_response(self, node, async_result, timeout):
try:
result = async_result.get_result(timeout)
node.version = result.get('v', node.version)
node.version = result.get(b'v', node.version)
self._nodes.good_node(node)
return result['r']
return result[b'r']
except AsyncTimeout: # The node did not reply
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC timeout %r' % node)
except KRPCError: # Some other error occured
self._log.exception('KRPC Error %r' % node)
if self._log.isEnabledFor(logging.INFO):
self._log.exception('KRPC Error %r' % node)
self._nodes.remove_node(node)
async_result.discard_result()
return {}
# Iterate KRPC function on closest nodes - query_fun(connection, id, search_value)
def _iter_krpc_search(self, query_fun, process_fun, search_value, timeout, retries):
id_cmp = int(search_value.encode('hex'), 16)
id_cmp = decode_id(search_value)
(returned, used_connections, discovered_nodes) = (set(), {}, set())
while True:
blacklist_connections = filter(lambda c: used_connections[c] > retries, used_connections)
discovered_nodes = set(filter(lambda n: n and (n.connection not in blacklist_connections), discovered_nodes))
close_nodes = set(self._nodes.get_nodes(N = 20,
expression = lambda n: n.connection not in blacklist_connections,
sorter = lambda n: n.id_cmp ^ id_cmp))
while not self._threads.shutdown_in_progress():
def above_retries(c):
return used_connections[c] > retries
blacklist_connections = set(filter(above_retries, used_connections))
def valid_node(n):
return n and (n.connection not in blacklist_connections)
discovered_nodes = set(filter(valid_node, discovered_nodes))
def not_blacklisted(n):
return n.connection not in blacklist_connections
def sort_by_id(n):
return n.id_cmp ^ id_cmp
close_nodes = set(self._nodes.get_nodes(N = 20, expression = not_blacklisted, sorter = sort_by_id))
if not close_nodes.union(discovered_nodes):
break
@ -285,12 +323,15 @@ class DHT(object):
node_result_list.append((node, async_result))
used_connections[node.connection] = used_connections.get(node.connection, 0) + 1
t_end = time.time() + timeout
for (node, async_result) in node_result_list: # sequentially retrieve results
if self._threads.shutdown_in_progress():
break
result = self._eval_dht_response(node, async_result, timeout = max(0, t_end - time.time()))
with self._node_lock:
node.pending -= 1
for node_id, node_connection in decode_nodes(result.get('nodes', '')):
for node_id, node_connection in decode_nodes(result.get(b'nodes', b'')):
discovered_nodes.add(self._nodes.register_node(node_connection, node_id))
for tmp in process_fun(node, result):
if tmp not in returned:
@ -306,97 +347,107 @@ class DHT(object):
# ping methods
# (sync method)
def dht_ping(self, connection, timeout = 1):
def dht_ping(self, connection, timeout = 5):
try:
result = self.ping(connection, self._node.id).get_result(timeout)
if result.get('r', {}).get('id'):
self._nodes.register_node(connection, result['r']['id'], result.get('v'))
return result.get('r', {})
if result.get(b'r', {}).get(b'id'):
self._nodes.register_node(connection, result[b'r'][b'id'], result.get(b'v'))
return result.get(b'r', {})
except (AsyncTimeout, KRPCError):
pass
# (verbatim, async KRPC method)
def ping(self, target_connection, sender_id):
return self._krpc.send_krpc_query(target_connection, 'ping', id = sender_id)
return self._krpc.send_krpc_query(target_connection, b'ping', id = sender_id)
# (reply method)
def _ping(self, send_krpc_reply, id, **kwargs):
def _ping(self, send_krpc_reply, id):
send_krpc_reply(id = self._node.id)
_reply_handler['ping'] = _ping
_reply_handler[b'ping'] = _ping
# find_node methods
# (sync method, iterating on close nodes)
def dht_find_node(self, search_id):
def dht_find_node(self, search_id, timeout = 5, retries = 2):
def process_find_node(node, result):
for node_id, node_connection in decode_nodes(result.get('nodes', '')):
for node_id, node_connection in decode_nodes(result.get(b'nodes', b'')):
if node_id == search_id:
yield node_connection
return self._iter_krpc_search(self.find_node, process_find_node, search_id, timeout = 5, retries = 2)
return self._iter_krpc_search(self.find_node, process_find_node, search_id, timeout, retries)
# (verbatim, async KRPC method)
def find_node(self, target_connection, sender_id, search_id):
return self._krpc.send_krpc_query(target_connection, 'find_node', id = sender_id, target = search_id)
return self._krpc.send_krpc_query(target_connection, b'find_node', id = sender_id, target = search_id)
# (reply method)
def _find_node(self, send_krpc_reply, id, target, **kwargs):
id_cmp = int(id.encode('hex'), 16)
def _find_node(self, send_krpc_reply, id, target):
id_cmp = decode_id(id)
def select_valid(n):
return valid_id(n.id, n.connection)
def sort_by_id(n):
return n.id_cmp ^ id_cmp
send_krpc_reply(id = self._node.id, nodes = encode_nodes(self._nodes.get_nodes(N = 20,
expression = lambda n: valid_id(n.id, n.connection),
sorter = lambda n: n.id_cmp ^ id_cmp)))
_reply_handler['find_node'] = _find_node
expression = select_valid, sorter = sort_by_id)))
_reply_handler[b'find_node'] = _find_node
# get_peers methods
# (sync method, iterating on close nodes)
def dht_get_peers(self, info_hash):
def dht_get_peers(self, info_hash, timeout = 5, retries = 2):
def process_get_peers(node, result):
if result.get('token'):
node.tokens[info_hash] = result['token'] # store token for subsequent announce_peer
for node_connection in map(decode_connection, result.get('values', '')):
if result.get(b'token'):
node.tokens[info_hash] = result[b'token'] # store token for subsequent announce_peer
for node_connection in map(decode_connection, result.get(b'values', b'')):
yield node_connection
return self._iter_krpc_search(self.get_peers, process_get_peers, info_hash, timeout = 5, retries = 2)
return self._iter_krpc_search(self.get_peers, process_get_peers, info_hash, timeout, retries)
# (verbatim, async KRPC method)
def get_peers(self, target_connection, sender_id, info_hash):
return self._krpc.send_krpc_query(target_connection, 'get_peers', id = sender_id, info_hash = info_hash)
return self._krpc.send_krpc_query(target_connection, b'get_peers', id = sender_id, info_hash = info_hash)
# (reply method)
def _get_peers(self, send_krpc_reply, id, info_hash, **kwargs):
token = hmac.new(self._token_key, send_krpc_reply.connection[0], hashlib.sha1).digest()
id_cmp = int(id.encode('hex'), 16)
reply_args = {'nodes': encode_nodes(self._nodes.get_nodes(N = 8,
expression = lambda n: valid_id(n.id, n.connection),
sorter = lambda n: n.id_cmp ^ id_cmp))}
def _get_peers(self, send_krpc_reply, id, info_hash):
token = hmac.new(self._token_key, encode_ip(send_krpc_reply.connection[0]), hashlib.sha1).digest()
id_cmp = decode_id(id)
def select_valid(n):
return valid_id(n.id, n.connection)
def sort_by_id(n):
return n.id_cmp ^ id_cmp
reply_args = {'nodes': encode_nodes(self._nodes.get_nodes(N = 8, expression = select_valid, sorter = sort_by_id))}
if self._node.values.get(info_hash):
reply_args['values'] = map(encode_connection, self._node.values[info_hash])
reply_args['values'] = list(map(encode_connection, self._node.values[info_hash]))
send_krpc_reply(id = self._node.id, token = token, **reply_args)
_reply_handler['get_peers'] = _get_peers
_reply_handler[b'get_peers'] = _get_peers
# announce_peer methods
# (sync method, announcing to all nodes giving tokens)
def dht_announce_peer(self, info_hash):
for node in self._nodes.get_nodes(expression = lambda n: info_hash in n.tokens):
def dht_announce_peer(self, info_hash, implied_port = 1):
def has_info_hash_token(node):
return info_hash in node.tokens
for node in self._nodes.get_nodes(expression = has_info_hash_token):
yield self.announce_peer(node.connection, self._node.id, info_hash, self._node.connection[1],
node.tokens[info_hash], implied_port = 1)
node.tokens[info_hash], implied_port = implied_port)
# (verbatim, async KRPC method)
def announce_peer(self, target_connection, sender_id, info_hash, port, token, implied_port = None):
req = {'id': sender_id, 'info_hash': info_hash, 'port': port, 'token': token}
if implied_port != None: # (optional) "1": port not reliable - remote should use source port
req['implied_port'] = implied_port
return self._krpc.send_krpc_query(target_connection, 'announce_peer', **req)
return self._krpc.send_krpc_query(target_connection, b'announce_peer', **req)
# (reply method)
def _announce_peer(self, send_krpc_reply, id, info_hash, port, token, implied_port = None, **kwargs):
local_token = hmac.new(self._token_key, send_krpc_reply.connection[0], hashlib.sha1).digest()
def _announce_peer(self, send_krpc_reply, id, info_hash, port, token, implied_port = None):
local_token = hmac.new(self._token_key, encode_ip(send_krpc_reply.connection[0]), hashlib.sha1).digest()
if (local_token == token) and valid_id(id, send_krpc_reply.connection): # Validate token and ID
if implied_port:
port = send_krpc_reply.connection[1]
self._node.values.setdefault(info_hash, []).append((send_krpc_reply.connection[0], port))
send_krpc_reply(id = self._node.id)
_reply_handler['announce_peer'] = _announce_peer
_reply_handler[b'announce_peer'] = _announce_peer
if __name__ == '__main__':
logging.basicConfig()
# logging.getLogger().setLevel(logging.INFO)
# logging.getLogger('DHT').setLevel(logging.INFO)
logging.getLogger('DHT_Router').setLevel(logging.DEBUG)
# logging.getLogger('KRPCPeer').setLevel(logging.INFO)
log = logging.getLogger()
log.setLevel(logging.INFO)
logging.getLogger('DHT').setLevel(logging.INFO)
logging.getLogger('DHT_Router').setLevel(logging.ERROR)
logging.getLogger('KRPCPeer').setLevel(logging.ERROR)
logging.getLogger('KRPCPeer.local').setLevel(logging.ERROR)
logging.getLogger('KRPCPeer.remote').setLevel(logging.ERROR)
# Create a DHT node
setup = {'report_t': 5, 'check_t': 2, 'check_N': 10, 'discover_t': 3}
# Create a DHT swarm
setup = {}
bootstrap_connection = ('localhost', 10001)
# bootstrap_connection = ('router.bittorrent.com', 6881)
dht1 = DHT(('0.0.0.0', 10001), bootstrap_connection, setup)
@ -406,29 +457,30 @@ if __name__ == '__main__':
dht5 = DHT(('0.0.0.0', 10005), ('localhost', 10003), setup)
dht6 = DHT(('0.0.0.0', 10006), ('localhost', 10005), setup)
print '\nping\n' + '=' * 20 # Ping bootstrap node
print dht1.dht_ping(bootstrap_connection)
print dht6.dht_ping(bootstrap_connection)
log.critical('starting "ping" test')
log.critical('ping: dht1 -> bootstrap = %r' % dht1.dht_ping(bootstrap_connection))
log.critical('ping: dht6 -> bootstrap = %r' % dht6.dht_ping(bootstrap_connection))
print '\nfind_node\n' + '=' * 20 # Search myself
for node in dht3.dht_find_node(dht1._node.id):
print '->', node
log.critical('starting "find_node" test')
for idx, node in enumerate(dht3.dht_find_node(dht1._node.id)):
log.critical('find_node: dht3 -> id(dht1) result #%d: %s:%d' % (idx, node[0], node[1]))
if idx > 10:
break
print '\nget_peers\n' + '=' * 20 # Search Ubuntu 14.04 info hash
info_hash = 'cb84ccc10f296df72d6c40ba7a07c178a4323a14'.decode('hex')
for peer in dht5.dht_get_peers(info_hash):
print '->', peer
import binascii
info_hash = binascii.unhexlify('cb84ccc10f296df72d6c40ba7a07c178a4323a14') # Ubuntu 14.04 info hash
print '\nannounce_peer\n' + '=' * 20 # Announce availability of info hash at dht5
print dht5.dht_announce_peer(info_hash)
log.critical('starting "get_peers" test')
for idx, peer in enumerate(dht5.dht_get_peers(info_hash)):
log.critical('get_peers: dht5 -> info_hash result #%d: %r' % (idx, peer))
print '\nget_peers\n' + '=' * 20
for peer in dht3.dht_get_peers(info_hash):
print '->', peer
log.critical('starting "announce_peer" test')
for idx, async_result in enumerate(dht5.dht_announce_peer(info_hash)):
log.critical('announce_peer: dht2 -> close_nodes(info_hash) #%d: %r' % (idx, async_result.get_result(1)))
print 'done...'
time.sleep(5*60)
dht1.shutdown()
dht6.shutdown()
print 'shutdown complete'
time.sleep(60*60)
log.critical('starting "get_peers" test')
for idx, peer in enumerate(dht1.dht_get_peers(info_hash)):
log.critical('get_peers: dht1 -> info_hash result #%d: %r' % (idx, peer))
for dht in [dht1, dht2, dht3, dht4, dht5, dht6]:
dht.shutdown()

183
krpc.py
View File

@ -22,11 +22,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import time, socket, select, threading, logging
from bencode import bencode, bdecode
from utils import client_version, start_thread, AsyncResult, AsyncTimeout, encode_int
import socket, threading, logging
from bencode import bencode, bdecode, BTFailure
from utils import client_version, AsyncResult, AsyncTimeout, encode_uint64, UDPSocket, ThreadManager
krpc_version = client_version[0] + chr(client_version[1]) + chr(client_version[2])
krpc_version = bytes(client_version[0] + bytearray([client_version[1], client_version[2]]))
class KRPCError(RuntimeError):
pass
@ -39,80 +39,20 @@ class KRPCPeer(object):
send_krpc_response(**kwargs) is a function to send a reply,
rec contains the dictionary with the incoming message.
"""
self._log = logging.getLogger(self.__class__.__name__)
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._sock.setblocking(0)
self._sock.bind(connection)
self._log = logging.getLogger(self.__class__.__name__ + '.%s:%d' % connection)
self._log_msg = self._log.getChild('msg') # message handling
self._log_local = self._log.getChild('local') # local queries
self._log_remote = self._log.getChild('remote') # remote queries
self._sock = UDPSocket(connection)
self._transaction = {}
self._transaction_id = 0
self._transaction_lock = threading.Lock()
self._handle_query = handle_query
self._shutdown_flag = False
self._listen_thread = start_thread(self._listen)
self._cleanup_thread = start_thread(self._cleanup_transactions,
timeout = cleanup_timeout, interval = cleanup_interval)
def shutdown(self):
""" This function allows to cleanly shutdown the KRPCPeer. """
self._shutdown_flag = True
def _cleanup_transactions(self, timeout = 60, interval = 10):
while not self._shutdown_flag:
# Remove transactions older than 1min
with self._transaction_lock:
timeout_transactions = list(filter(lambda t: self._transaction[t].get_age() > timeout, self._transaction))
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('Transactions: %d id=%d timeout=%d' % (len(self._transaction), self._transaction_id, len(timeout_transactions)))
for t in timeout_transactions:
self._transaction.pop(t).set_result(AsyncTimeout('Transaction %r: timeout' % t))
time.sleep(interval)
def _listen(self):
while not self._shutdown_flag:
try:
if select.select([self._sock], [], [], 10)[0]:
(encoded_rec, source_connection) = self._sock.recvfrom(64*1024)
try:
rec = bdecode(encoded_rec)
except:
self._log.exception('Exception while parsing KRPC requests from %r:\n\t%r' % (source_connection, encoded_rec))
continue
if rec['y'] in ['r', 'e']: # Response / Error message
t = rec['t']
if rec['y'] == 'e':
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC error message from %r:\n\t%r' % (source_connection, rec))
rec = KRPCError('Error while processing transaction %r:\n\t%r' % (t, rec))
else:
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC answer from %r:\n\t%r' % (source_connection, rec))
with self._transaction_lock:
if self._transaction.get(t):
self._transaction.pop(t).set_result(rec, source = source_connection)
elif self._log.isEnabledFor(logging.INFO):
self._log.info('Received response from %r without associated transaction:\n%r' % (source_connection, rec))
elif rec['y'] == 'q':
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC request from %r:\n\t%r' % (source_connection, rec))
send_krpc_response = lambda message, top_level_message = {}:\
self._send_krpc_response(source_connection, rec.pop('t'), message, top_level_message)
self._handle_query(send_krpc_response, rec, source_connection)
else:
if self._log.isEnabledFor(logging.INFO):
self._log.info('Unknown type of KRPC message from %r:\n\t%r' % (source_connection, rec))
except Exception:
self._log.exception('Exception while handling KRPC requests from %r:\n\t%r' % (source_connection, rec))
self._sock.close()
def _send_krpc_response(self, source_connection, remote_transaction, message, top_level_message = {}):
with self._transaction_lock:
resp = {'y': 'r', 't': remote_transaction, 'v': krpc_version, 'r': message}
resp.update(top_level_message)
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC response to %r:\n\t%r' % (source_connection, resp))
self._sock.sendto(bencode(resp), source_connection)
self._threads = ThreadManager(self._log)
self._threads.start_continuous_thread(self._listen)
self._threads.start_continuous_thread(self._cleanup_transactions,
thread_interval = cleanup_interval, timeout = cleanup_timeout)
def send_krpc_query(self, target_connection, method, **kwargs):
""" Invoke method on the node at target_connection.
@ -124,15 +64,94 @@ class KRPCPeer(object):
with self._transaction_lock:
while True: # Generate transaction id
self._transaction_id += 1
local_transaction = encode_int(self._transaction_id).lstrip('\x00')
local_transaction = bytes(bytearray(encode_uint64(self._transaction_id)).lstrip(b'\x00'))
if local_transaction not in self._transaction:
break
req = {'y': 'q', 't': local_transaction, 'v': krpc_version, 'q': method, 'a': kwargs}
self._transaction[local_transaction] = AsyncResult(source = (method, kwargs, target_connection))
req = {b'y': b'q', b't': local_transaction, b'v': krpc_version, b'q': method, b'a': kwargs}
result = AsyncResult(source = (method, kwargs, target_connection))
if not self._threads.shutdown_in_progress():
if self._log_local.isEnabledFor(logging.INFO):
self._log_local.info('KRPC request to %r:\n\t%r' % (target_connection, req))
self._transaction[local_transaction] = result
self._sock.sendto(bencode(req), target_connection)
else:
result.set_result(AsyncTimeout('Shutdown in progress'))
return result
def shutdown(self):
""" This function allows to cleanly shutdown the KRPCPeer. """
self._threads.shutdown()
self._sock.close()
with self._transaction_lock:
for t in list(self._transaction):
self._transaction.pop(t).set_result(AsyncTimeout('Shutdown in progress'))
self._threads.join()
# Private members #################################################
def _cleanup_transactions(self, timeout):
# Remove transactions older than 1min
with self._transaction_lock:
timeout_transactions = [t for t, ar in self._transaction.items() if ar.get_age() > timeout]
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('KRPC request to %r:\n\t%r' % (target_connection, req))
self._sock.sendto(bencode(req), target_connection)
return self._transaction[local_transaction]
self._log.debug('Transactions: %d id=%d timeout=%d' % (len(self._transaction), self._transaction_id, len(timeout_transactions)))
for t in timeout_transactions:
self._transaction.pop(t).set_result(AsyncTimeout('Transaction %r: timeout' % t))
def _listen(self):
try:
recv_data = self._sock.recvfrom(timeout = 0.2)
if not recv_data:
return
(encoded_rec, source_connection) = recv_data
try:
rec = bdecode(encoded_rec)
except BTFailure:
if self._log_msg.isEnabledFor(logging.ERROR):
self._log_msg.error('Error while parsing KRPC requests from %r:\n\t%r' % (source_connection, encoded_rec))
return
except Exception:
return self._log_msg.exception('Exception while handling KRPC requests from %r:\n\t%r' % (source_connection, encoded_rec))
try:
if rec[b'y'] in [b'r', b'e']: # Response / Error message
t = rec[b't']
if rec[b'y'] == b'e':
if self._log_local.isEnabledFor(logging.ERROR):
self._log_local.error('KRPC error message from %r:\n\t%r' % (source_connection, rec))
with self._transaction_lock:
if self._transaction.get(t):
rec = KRPCError('Error while processing transaction %r:\n\t%r\n\t%r' % (t, rec, self._transaction.get(t).get_source()))
else:
rec = KRPCError('Error while processing transaction %r:\n\t%r' % (t, rec))
else:
if self._log_local.isEnabledFor(logging.INFO):
self._log_local.info('KRPC answer from %r:\n\t%r' % (source_connection, rec))
with self._transaction_lock:
if self._transaction.get(t):
self._transaction.pop(t).set_result(rec, source = source_connection)
elif self._log_local.isEnabledFor(logging.DEBUG):
self._log_local.debug('Received response from %r without associated transaction:\n%r' % (source_connection, rec))
elif rec[b'y'] == b'q':
if self._log_remote.isEnabledFor(logging.INFO):
self._log_remote.info('KRPC request from %r:\n\t%r' % (source_connection, rec))
def custom_send_krpc_response(message, top_level_message = {}):
return self._send_krpc_response(source_connection, rec.pop(b't'), message, top_level_message, self._log_remote)
self._handle_query(custom_send_krpc_response, rec, source_connection)
else:
if self._log_msg.isEnabledFor(logging.ERROR):
self._log_msg.error('Unknown type of KRPC message from %r:\n\t%r' % (source_connection, rec))
except Exception:
self._log_msg.exception('Exception while handling KRPC requests from %r:\n\t%r' % (source_connection, rec))
def _send_krpc_response(self, source_connection, remote_transaction, message, top_level_message = {}, log = None):
with self._transaction_lock:
resp = {b'y': b'r', b't': remote_transaction, b'v': krpc_version, b'r': message}
resp.update(top_level_message)
if log == None:
log = self._log_local
if log.isEnabledFor(logging.INFO):
log.info('KRPC response to %r:\n\t%r' % (source_connection, resp))
self._sock.sendto(bencode(resp), source_connection)
if __name__ == '__main__':
@ -140,5 +159,7 @@ if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
# Implement an echo message
peer = KRPCPeer(('0.0.0.0', 1111), handle_query = lambda send_krpc_response, rec, source_connection:
send_krpc_response(message = 'Hello %s!' % rec['a']['message']))
print peer.send_krpc_query(('localhost', 1111), 'echo', message = 'World').get_result(2)
send_krpc_response(message = 'Hello %s!' % rec[b'a'][b'message']))
query = peer.send_krpc_query(('localhost', 1111), 'echo', message = 'World')
logging.getLogger().critical('result = %r' % query.get_result(2))
peer.shutdown()

199
utils.py
View File

@ -22,46 +22,55 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import socket, struct, threading, time
import sys, select, socket, struct, threading, time, collections, logging
client_version = ('XK', 0, 0x01) # eXperimental Klient 0.0.1
client_version = (b'XK', 0, 0x01) # eXperimental Klient 0.0.1
encode_short = lambda value: struct.pack('!H', value)
encode_int = lambda value: struct.pack('!I', value)
encode_ip = lambda value: socket.inet_aton(value)
encode_uint16 = lambda value: struct.pack('!H', value)
encode_uint32 = lambda value: struct.pack('!I', value)
encode_uint64 = lambda value: struct.pack('!Q', value)
encode_int32 = lambda value: struct.pack('!i', value)
def encode_connection(con):
return encode_ip(con[0]) + encode_short(con[1])
return encode_ip(con[0]) + encode_uint16(con[1])
def encode_nodes(nodes):
result = ''
result = b''
for node in nodes:
result += struct.pack('20s', node.id) + encode_connection(node.connection)
result += bytes(bytearray(node.id).rjust(20, b'\0')) + encode_connection(node.connection)
return result
decode_short = lambda value: struct.unpack('!H', value)[0]
decode_int = lambda value: struct.unpack('!I', value)[0]
decode_ip = lambda value: socket.inet_ntoa(value)
decode_uint16 = lambda value: struct.unpack('!H', value)[0]
decode_uint32 = lambda value: struct.unpack('!I', value)[0]
decode_uint64 = lambda value: struct.unpack('!Q', value)[0]
def decode_connection(con):
return (decode_ip(con[0:4]), decode_short(con[4:6]))
return (decode_ip(con[0:4]), decode_uint16(con[4:6]))
def decode_nodes(nodes):
while nodes:
node_id = struct.unpack('20s', nodes[:20])[0]
node_connection = decode_connection(nodes[20:26])
yield (node_id, node_connection)
nodes = nodes[26:]
try:
while nodes:
node_id = struct.unpack('20s', nodes[:20])[0]
node_connection = decode_connection(nodes[20:26])
if node_connection[1] >= 1024: # discard invalid port numbers
yield (node_id, node_connection)
nodes = nodes[26:]
except Exception:
pass # catch malformed nodes
def start_thread(fun, *args, **kwargs):
thread = threading.Thread(target=fun, args=args, kwargs=kwargs)
thread = threading.Thread(name = repr(fun), target=fun, args=args, kwargs=kwargs)
thread.daemon = True
thread.start()
return thread
class AsyncTimeout(RuntimeError):
pass
class AsyncResult(object):
def __init__(self, source = None):
self._event = threading.Event()
@ -77,7 +86,8 @@ class AsyncResult(object):
def set_result(self, result, source = None):
self._value = result
self._source = source
if source != None:
self._source = source
self._event.set()
def has_result(self):
@ -92,3 +102,158 @@ class AsyncResult(object):
if isinstance(self._value, Exception):
raise self._value
return self._value
class ThreadManager(object):
def __init__(self, log):
self._log = log
self._threads = []
self._shutdown_event = threading.Event()
def shutdown_in_progress(self):
return self._shutdown_event.is_set()
def shutdown(self):
self._shutdown_event.set() # Trigger shutdown of threads
def join(self, timeout = 60):
self.shutdown()
for thread in self._threads:
thread.join(timeout)
def start_thread(self, name, daemon, fun, *args, **kwargs):
thread = threading.Thread(name = name, target=fun, args=args, kwargs=kwargs)
thread.daemon = daemon
thread.start()
self._threads.append(thread)
return thread
def start_continuous_thread(self, fun, thread_interval = 0, *args, **kwargs):
if thread_interval >= 0:
self.start_thread('continuous thread:' + repr(fun), False,
self._continuous_thread_wrapper, fun, thread_interval = thread_interval, *args, **kwargs)
def _continuous_thread_wrapper(self, fun, on_except = ['log', 'continue'], thread_waitfirst = False, thread_interval = 0, *args, **kwargs):
if thread_waitfirst:
self._shutdown_event.wait(thread_interval)
while not self._shutdown_event.is_set():
try:
fun(*args, **kwargs)
except Exception:
if 'log' in on_except:
self._log.exception('Exception in maintainance thread')
if 'continue' not in on_except:
return
self._shutdown_event.wait(thread_interval)
class NetworkSocket(object):
def __init__(self, name):
self._log = logging.getLogger(self.__class__.__name__).getChild(name)
self._threads = ThreadManager(self._log)
self._lock = threading.Lock()
self._send_event = threading.Event()
self._send_queue = collections.deque()
self._send_try = 0
self._recv_event = threading.Event()
self._recv_queue = collections.deque()
self._force_show_info = False
self._threads.start_continuous_thread(self._info_thread, thread_interval = 0.5)
self._threads.start_continuous_thread(self._send_thread)
self._threads.start_continuous_thread(self._recv_thread)
# Non-blocking send
def sendto(self, *args):
self._send_queue.append(args)
with self._lock: # set send flag
self._send_event.set()
# Blocking read - with timeout
def recvfrom(self, timeout = None):
result = None
if self._recv_event.wait(timeout):
if self._recv_queue:
result = self._recv_queue.pop()
with self._lock:
if not self._recv_queue and not self._threads.shutdown_in_progress():
self._recv_event.clear()
return result
def close(self):
with self._lock:
self._threads.shutdown()
self._send_event.set()
self._recv_event.set()
self._close()
self._threads.join()
# Private members #################################################
def _info_thread(self):
if (len(self._recv_queue) > 20) or (len(self._send_queue) > 20) or self._force_show_info:
if self._log.isEnabledFor(logging.DEBUG):
self._log.debug('recv: %4d, send: %4d' % (len(self._recv_queue), len(self._send_queue)))
self._force_show_info = True
if not(len(self._recv_queue) or len(self._send_queue)):
self._force_show_info = False
def _send_thread(self, send_tries = 100):
if self._send_event.wait(0.1):
if self._send_queue:
if self._send(*self._send_queue[0]):
self._send_queue.popleft()
self._send_try = 0
elif self._send_try > send_tries:
self._send_queue.popleft()
else:
self._send_queue.rotate(-1)
self._send_try += 1
with self._lock: # clear send flag
if not self._send_queue and not self._threads.shutdown_in_progress():
self._send_event.clear()
def _send(self, *args):
raise NotImplemented
def _recv_thread(self):
tmp = self._recv()
if tmp:
self._recv_queue.append(tmp)
with self._lock:
self._recv_event.set()
def _recv(self):
raise NotImplemented
def _close(self):
raise NotImplemented
class UDPSocket(NetworkSocket):
def __init__(self, connection):
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._sock.setblocking(0)
self._sock.bind(connection)
NetworkSocket.__init__(self, '%s:%d' % connection)
def _send(self, *args):
select.select([], [self._sock], [], 0.1)
try:
self._sock.sendto(*args)
return True
except socket.error:
pass
def _recv(self):
select.select([self._sock], [], [], 0.1)
try:
return self._sock.recvfrom(64*1024)
except socket.error:
pass
def _close(self):
self._sock.close()