wait for packet arrive

master
Роман Бородин 2019-04-11 19:33:44 +03:00
parent 4fff6e42fb
commit 328692bd20
2 changed files with 33 additions and 30 deletions

View File

@ -15,7 +15,7 @@ class UTCPSocketStream(SocketStream):
MAX_IO_CHUNK = utcp.DATA_LENGTH MAX_IO_CHUNK = utcp.DATA_LENGTH
@classmethod @classmethod
def utcp_connect(cls, host, port, *a, **kw): def utcp_connect(cls, host, port, *a, **kw):
sock = utcp.TCP(encrypted=True) sock = utcp.UTCP(encrypted=True)
sock.connect((host, port)) sock.connect((host, port))
return cls(sock) return cls(sock)
def poll(self, timeout): def poll(self, timeout):
@ -54,7 +54,7 @@ class UTCPThreadedServer(ThreadedServer):
self.listener.close() self.listener.close()
self.listener = None self.listener = None
########## ##########
self.listener = utcp.TCP(encrypted=True) self.listener = utcp.UTCP(encrypted=True)
self.listener.bind((hostname, port)) self.listener.bind((hostname, port))
sockname = self.listener.getsockname() sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1] self.host, self.port = sockname[0], sockname[1]

View File

@ -1,4 +1,4 @@
# based on https://github.com/ethay012/TCP-over-UDP # based on https://github.com/ethay012/UTCP-over-UDP
import random import random
import socket import socket
import pickle import pickle
@ -14,9 +14,6 @@ DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info PACKET_HEADER_SIZE = 512 # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH DATA_LENGTH = DATA_DIVIDE_LENGTH
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272 # Encrypted data always 272 bytes bigger SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272 # Encrypted data always 272 bytes bigger
LAST_CONNECTION = -1
FIRST = 0
PACKET_END = b'___+++^^^END^^^+++___'
class Connection: class Connection:
SMALLEST_STARTING_SEQ = 0 SMALLEST_STARTING_SEQ = 0
@ -24,7 +21,7 @@ class Connection:
def __init__(self, remote, encrypted=False): def __init__(self, remote, encrypted=False):
self.peer_addr = remote self.peer_addr = remote
self.seq = Connection.gen_starting_seq_num() self.seq = Connection.gen_starting_seq_num()
self.recv_seq = 0 self.recv_seq = -1
self.my_key = None self.my_key = None
if encrypted: if encrypted:
self.my_key = simplecrypto.RsaKeypair() self.my_key = simplecrypto.RsaKeypair()
@ -80,7 +77,7 @@ class Syn(UTCPPacket):
self.id = uuid.uuid4().bytes self.id = uuid.uuid4().bytes
self.seq = 0 self.seq = 0
def set_pub(self, pubkey): def set_pub(self, pubkey):
self.checksum = TCP.checksum(pubkey) self.checksum = UTCP.checksum(pubkey)
self.pubkey = pubkey self.pubkey = pubkey
class SynAck(UTCPPacket): class SynAck(UTCPPacket):
type = 'SYN-ACK' type = 'SYN-ACK'
@ -89,7 +86,7 @@ class SynAck(UTCPPacket):
self.id = id_ self.id = id_
self.seq = 0 self.seq = 0
def set_pub(self, pubkey): def set_pub(self, pubkey):
self.checksum = TCP.checksum(pubkey) self.checksum = UTCP.checksum(pubkey)
self.pubkey = pubkey self.pubkey = pubkey
class Packet(UTCPPacket): class Packet(UTCPPacket):
@ -106,7 +103,7 @@ class Packet(UTCPPacket):
return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data) return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data)
def set_data(self, data): def set_data(self, data):
self.checksum = TCP.checksum(data) self.checksum = UTCP.checksum(data)
self.data = data self.data = data
@ -127,7 +124,6 @@ def restricted_pickle_loads(s):
class UTCPChannel: class UTCPChannel:
HEADER = Struct('!I') HEADER = Struct('!I')
TERMINATOR = b'\n'
POLL_TIMEOUT = 0.1 POLL_TIMEOUT = 0.1
def __init__(self, sock): def __init__(self, sock):
self.sock = sock self.sock = sock
@ -136,11 +132,10 @@ class UTCPChannel:
def recv(self): def recv(self):
header = self.sock.recv(self.HEADER.size) header = self.sock.recv(self.HEADER.size)
data_len = self.HEADER.unpack(header)[0] data_len = self.HEADER.unpack(header)[0]
terminated_data = self.sock.recv(data_len + len(self.TERMINATOR)) return self.sock.recv(data_len)
return terminated_data[:-len(self.TERMINATOR)]
def send(self, data): def send(self, data):
header = self.HEADER.pack(len(data)) header = self.HEADER.pack(len(data))
self.sock.send(header + data + self.TERMINATOR) self.sock.send(header + data)
class ConnectedSOCK(object): class ConnectedSOCK(object):
def __init__(self, low_sock, client_addr): def __init__(self, low_sock, client_addr):
@ -183,15 +178,16 @@ class ConnectedSOCK(object):
return len(conn.packet_buffer['DATA']) return len(conn.packet_buffer['DATA'])
return False return False
class TCP(object): class UTCP(object):
host = None host = None
port = None port = None
client = False client = False
def __init__(self, af_type=None, sock_type=None, encrypted=False): def __init__(self,encrypted=False, **kw):
self.encrypted = encrypted self.encrypted = encrypted
self.incoming_packet_event = threading.Event() self.incoming_packet_event = threading.Event()
self.new_conn_event = threading.Event() self.new_conn_event = threading.Event()
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.own_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.settimeout() self.settimeout()
self.connection_lock = threading.Lock() self.connection_lock = threading.Lock()
self.queue_lock = threading.Lock() self.queue_lock = threading.Lock()
@ -204,13 +200,13 @@ class TCP(object):
connection = list(self.connections.keys())[0] connection = list(self.connections.keys())[0]
conn = self.connections[connection] conn = self.connections[connection]
with conn.recv_lock: with conn.recv_lock:
has_data = len(conn.packet_buffer['DATA']) has_data = bool(len(conn.packet_buffer['DATA']))
if has_data: if has_data:
return True return True
else: else:
self.incoming_packet_event.wait(timeout) self.incoming_packet_event.wait(timeout)
with conn.recv_lock: with conn.recv_lock:
return len(conn.packet_buffer['DATA']) return bool(len(conn.packet_buffer['DATA']))
return False return False
def get_free_port(self): def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -225,7 +221,7 @@ class TCP(object):
def setblocking(self, mode): def setblocking(self, mode):
self.own_socket.setblocking(mode) self.own_socket.setblocking(mode)
def __repr__(self): def __repr__(self):
return 'TCP()' return 'UTCP()'
def __str__(self): def __str__(self):
return 'Connections: %s' \ return 'Connections: %s' \
@ -251,12 +247,12 @@ class TCP(object):
else: else:
raise EOFError('Connection not in connected devices') raise EOFError('Connection not in connected devices')
conn = self.connections[connection] conn = self.connections[connection]
data_parts = TCP.data_divider(data) data_parts = UTCP.data_divider(data)
for data_part in data_parts: for data_part in data_parts:
data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part)
packet = Packet() packet = Packet()
packet.set_data(data_chunk) packet.set_data(data_chunk)
answer = self.__send_packet(connection, packet, retransmit=True) self.__send_packet(connection, packet, retransmit=True)
return len(data) return len(data)
except socket.error as error: except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
@ -341,6 +337,7 @@ class TCP(object):
if self.connection_queue: if self.connection_queue:
with self.queue_lock: with self.queue_lock:
answer, conn = self.connection_queue.pop() answer, conn = self.connection_queue.pop()
conn.recv_seq = answer.seq
self.connections[conn.peer_addr] = conn self.connections[conn.peer_addr] = conn
syn_ack = SynAck(answer.id) syn_ack = SynAck(answer.id)
if self.encrypted: if self.encrypted:
@ -410,7 +407,7 @@ class TCP(object):
raise Exception('The receiver didn\'t send the fin packet') raise Exception('The receiver didn\'t send the fin packet')
else: else:
self.drop_connection(connection) self.drop_connection(connection)
if len(self.connections) == 0 and self.client: if not len(self.connections) and self.client:
self.stop() self.stop()
except Exception as error: except Exception as error:
raise EOFError('Something went wrong in the close func! Error is: %s.' % error) raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
@ -421,7 +418,7 @@ class TCP(object):
self.__send_packet(connection, ack) self.__send_packet(connection, ack)
fin_ack = FinAck(fin_id) fin_ack = FinAck(fin_id)
try: try:
answer = self.__send_packet(connection, fin_ack) self.__send_packet(connection, fin_ack)
except: except:
pass pass
with self.connection_lock: with self.connection_lock:
@ -454,7 +451,9 @@ class TCP(object):
if condition == 'DATA': if condition == 'DATA':
if len(conn.packet_buffer[condition]): if len(conn.packet_buffer[condition]):
data = b'' data = b''
while size and conn.packet_buffer[condition]: while size:
if not self.poll(0.5):
continue
with conn.recv_lock: with conn.recv_lock:
packet = conn.packet_buffer[condition][0] packet = conn.packet_buffer[condition][0]
chunk = packet.data[:size] chunk = packet.data[:size]
@ -488,21 +487,25 @@ class TCP(object):
def blink_new_conn_event(self): def blink_new_conn_event(self):
self.new_conn_event.set() self.new_conn_event.set()
self.new_conn_event.clear() self.new_conn_event.clear()
def sort_answers(self, packet, address): def multiplex(self, packet, address):
if address not in self.connections and not isinstance(packet, Syn): if address not in self.connections and not isinstance(packet, Syn):
return return
if not isinstance(packet, Syn): if not isinstance(packet, Syn):
conn = self.connections[address] conn = self.connections[address]
if conn.recv_seq == packet.seq: if isinstance(packet, SynAck):
conn.recv_seq = packet.seq
elif conn.recv_seq == packet.seq: # Repeat ACK
ack = Ack(packet.id) ack = Ack(packet.id)
self.__send_packet(address, ack) self.__send_packet(address, ack)
return return
elif conn.recv_seq > packet.seq: elif packet.seq < conn.recv_seq: # Possibly DUP
return
elif packet.seq > (conn.recv_seq + 1): # Intermediate packet lost
return return
else: else:
conn.recv_seq = packet.seq conn.recv_seq = packet.seq
if isinstance(packet, Packet): if isinstance(packet, Packet):
if packet.checksum == TCP.checksum(packet.data): if packet.checksum == UTCP.checksum(packet.data):
if self.encrypted: if self.encrypted:
packet.data = conn.my_key.decrypt_raw(packet.data) packet.data = conn.my_key.decrypt_raw(packet.data)
else: else:
@ -526,12 +529,12 @@ class TCP(object):
try: try:
packet, address = self.own_socket.recvfrom(SENT_SIZE) packet, address = self.own_socket.recvfrom(SENT_SIZE)
packet = restricted_pickle_loads(packet) packet = restricted_pickle_loads(packet)
self.sort_answers(packet, address) self.multiplex(packet, address)
except pickle.UnpicklingError: except pickle.UnpicklingError:
continue continue
except socket.timeout: except socket.timeout:
continue continue
except socket.error as error: except socket.error:
self.own_socket.close() self.own_socket.close()
self.status = 0 self.status = 0