stream and channel complete
parent
76b4cb0766
commit
aa9d0ed929
|
@ -5,6 +5,11 @@ from uuid import uuid4
|
||||||
import datetime, io
|
import datetime, io
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from struct import Struct
|
||||||
|
try:
|
||||||
|
import zlib
|
||||||
|
except:
|
||||||
|
zlib = None
|
||||||
|
|
||||||
peers = {}
|
peers = {}
|
||||||
|
|
||||||
|
@ -26,6 +31,7 @@ PACKET_TYPE_CONFIRM_RECV = 0xa1
|
||||||
PACKET_TYPE_GOODBUY = 0xff
|
PACKET_TYPE_GOODBUY = 0xff
|
||||||
|
|
||||||
class InvalidPacket(Exception): pass
|
class InvalidPacket(Exception): pass
|
||||||
|
class OldPacket(Exception): pass
|
||||||
|
|
||||||
def pickle_data(data):
|
def pickle_data(data):
|
||||||
return pickle.dumps(data, protocol=4)
|
return pickle.dumps(data, protocol=4)
|
||||||
|
@ -154,7 +160,7 @@ class Peer:
|
||||||
def check_received_packet(self, packet):
|
def check_received_packet(self, packet):
|
||||||
if self.last_received_packet_num_reset_time:
|
if self.last_received_packet_num_reset_time:
|
||||||
if self.last_received_packet_num_reset_time > packet.reset_timestamp:
|
if self.last_received_packet_num_reset_time > packet.reset_timestamp:
|
||||||
raise EOFError('packet from past')
|
raise OldPacket('packet from past')
|
||||||
elif self.last_received_packet_num_reset_time < packet.reset_timestamp:
|
elif self.last_received_packet_num_reset_time < packet.reset_timestamp:
|
||||||
self.last_received_packet_num_reset_time = packet.reset_timestamp
|
self.last_received_packet_num_reset_time = packet.reset_timestamp
|
||||||
if (self.last_received_packet_num + 1) != packet.num:
|
if (self.last_received_packet_num + 1) != packet.num:
|
||||||
|
@ -206,11 +212,14 @@ class Peer:
|
||||||
############################################
|
############################################
|
||||||
else:
|
else:
|
||||||
if packet.type == PACKET_TYPE_PACKET:
|
if packet.type == PACKET_TYPE_PACKET:
|
||||||
|
try:
|
||||||
|
self.check_received_packet(packet)
|
||||||
|
except OldPacket:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
raw = self.my_key.decrypt_raw(packet.data)
|
raw = self.my_key.decrypt_raw(packet.data)
|
||||||
except:
|
except:
|
||||||
raise EOFError('decrypt packet error')
|
raise EOFError('decrypt packet error')
|
||||||
self.check_received_packet(packet)
|
|
||||||
self.put_block(raw)
|
self.put_block(raw)
|
||||||
else:
|
else:
|
||||||
raise EOFError('connection lost')
|
raise EOFError('connection lost')
|
||||||
|
@ -248,7 +257,6 @@ udpserver_thread = threading.Thread(target=udpserver.serve_forever)
|
||||||
udpserver_thread.start()
|
udpserver_thread.start()
|
||||||
|
|
||||||
class EncryptedUDPStream:
|
class EncryptedUDPStream:
|
||||||
MAX_IO_CHUNK = 8000
|
|
||||||
def __init__(self, sock, peer_addr):
|
def __init__(self, sock, peer_addr):
|
||||||
self.peer_addr = peer_addr
|
self.peer_addr = peer_addr
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
|
@ -280,7 +288,7 @@ class EncryptedUDPStream:
|
||||||
except:
|
except:
|
||||||
self.close()
|
self.close()
|
||||||
raise EOFError
|
raise EOFError
|
||||||
def read(self, count):
|
def read(self):
|
||||||
try:
|
try:
|
||||||
buf = peers[self.peer_addr].get_next_block()
|
buf = peers[self.peer_addr].get_next_block()
|
||||||
except:
|
except:
|
||||||
|
@ -292,4 +300,53 @@ class EncryptedUDPStream:
|
||||||
except:
|
except:
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
|
||||||
|
class Channel(object):
|
||||||
|
MAX_IO_CHUNK = 8000
|
||||||
|
COMPRESSION_THRESHOLD = 3000
|
||||||
|
COMPRESSION_LEVEL = 1
|
||||||
|
FRAME_HEADER = Struct("!LB")
|
||||||
|
FLUSHER = b'\n'
|
||||||
|
__slots__ = ["stream", "compress"]
|
||||||
|
|
||||||
|
def __init__(self, stream, compress = True):
|
||||||
|
self.stream = stream
|
||||||
|
if not zlib:
|
||||||
|
compress = False
|
||||||
|
self.compress = compress
|
||||||
|
def close(self):
|
||||||
|
self.stream.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self):
|
||||||
|
return self.stream.closed
|
||||||
|
def fileno(self):
|
||||||
|
return self.stream.fileno()
|
||||||
|
|
||||||
|
def poll(self, timeout):
|
||||||
|
return self.stream.poll(timeout)
|
||||||
|
|
||||||
|
def recv(self):
|
||||||
|
header = self.stream.read()
|
||||||
|
if len(header) != self.FRAME_HEADER.size:
|
||||||
|
raise EOFError('CHANNEL: Not a header received')
|
||||||
|
length, compressed = self.FRAME_HEADER.unpack(header)
|
||||||
|
block_len = length + len(self.FLUSHER)
|
||||||
|
full_block = b''.join((self.stream.read() for x in range(0, block_len, self.MAX_IO_CHUNK)))
|
||||||
|
if len(full_block) != block_len:
|
||||||
|
raise EOFError('CHANNEL: Received block with wrong size')
|
||||||
|
data = full_block[:-len(self.FLUSHER)]
|
||||||
|
if compressed:
|
||||||
|
data = zlib.decompress(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
|
||||||
|
compressed = 1
|
||||||
|
data = zlib.compress(data, self.COMPRESSION_LEVEL)
|
||||||
|
else:
|
||||||
|
compressed = 0
|
||||||
|
header = self.FRAME_HEADER.pack(len(data), compressed)
|
||||||
|
self.stream.write(header)
|
||||||
|
buf = data + self.FLUSHER
|
||||||
|
for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
|
||||||
|
self.stream.write(buf[chunk_start:self.MAX_IO_CHUNK])
|
||||||
|
|
Loading…
Reference in New Issue