avoid any ack, seq, flags... blah, blah, blah... 21 century...

utcp_uuid
Роман Бородин 2019-04-08 17:08:53 +03:00
parent 8aa6105c63
commit 544a98edd3
1 changed files with 121 additions and 112 deletions

View File

@ -6,6 +6,8 @@ import threading
import io import io
import hashlib import hashlib
import simplecrypto import simplecrypto
from struct import Struct
import uuid
DATA_DIVIDE_LENGTH = 8000 DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info PACKET_HEADER_SIZE = 512 # Pickle service info
@ -43,10 +45,33 @@ class Connection:
self.ack = ack self.ack = ack
return ack return ack
class TCPPacket(object): class Ack:
def __init__(self, seq): def __init__(self, id_):
self.seq = seq self.id = id_
self.ack = 0 class Fin:
def __init__(self):
self.id = uuid.uuid4().bytes
class FinAck:
def __init__(self, id_):
self.id = id_
class Syn:
checksum = None
def __init__(self):
self.id = uuid.uuid4().bytes
def set_pub(self, pubkey):
self.checksum = TCP.checksum(pubkey)
self.pubkey = pubkey
class SynAck:
checksum = None
def __init__(self, id_):
self.id = id_
def set_pub(self, pubkey):
self.checksum = TCP.checksum(pubkey)
self.pubkey = pubkey
class Packet:
def __init__(self):
self.id = uuid.uuid4().bytes
self.flag_ack = 0 self.flag_ack = 0
self.flag_syn = 0 self.flag_syn = 0
self.flag_fin = 0 self.flag_fin = 0
@ -56,41 +81,8 @@ class TCPPacket(object):
return f'TCPpacket(type={self.packet_type()})' return f'TCPpacket(type={self.packet_type()})'
def __str__(self): def __str__(self):
return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \ return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data)
% (self.seq, self.ack, self.flag_ack, self.flag_syn, self.flag_fin, self.packet_type(), self.data)
def __cmp__(self, other):
return (self.seq > other.seq) - (self.seq < other.seq)
def packet_type(self):
packet_type = ''
if self.flag_syn == 1 and self.flag_ack == 1:
packet_type = 'SYN-ACK'
elif self.flag_ack == 1 and self.flag_fin == 1:
packet_type = 'FIN-ACK'
elif self.flag_syn == 1:
packet_type = 'SYN'
elif self.flag_ack == 1:
packet_type = 'ACK'
elif self.flag_fin == 1:
packet_type = 'FIN'
elif self.data != b'':
packet_type = 'DATA'
return packet_type
def set_flags(self, ack=False, syn=False, fin=False):
if ack:
self.flag_ack = 1
else:
self.flag_ack = 0
if syn:
self.flag_syn = 1
else:
self.flag_syn = 0
if fin:
self.flag_fin = 1
else:
self.flag_fin = 0
def set_data(self, data): def set_data(self, data):
self.checksum = TCP.checksum(data) self.checksum = TCP.checksum(data)
self.data = data self.data = data
@ -99,17 +91,40 @@ class TCPPacket(object):
class RestrictedUnpickler(pickle.Unpickler): class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name): def find_class(self, module, name):
if name == 'TCPPacket': if module != 'builtins':
return TCPPacket if name == 'Packet': return Packet
if name == 'Ack': return Ack
if name == 'Fin': return Fin
if name == 'FinAck': return FinAck
if name == 'Syn': return Syn
if name == 'SynAck': return SynAck
raise pickle.UnpicklingError("global '%s.%s' is forbidden" % raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name)) (module, name))
def restricted_pickle_loads(s): def restricted_pickle_loads(s):
return RestrictedUnpickler(io.BytesIO(s)).load() return RestrictedUnpickler(io.BytesIO(s)).load()
class UTCPChannel:
HEADER = Struct('!I')
TERMINATOR = b'\n'
POLL_TIMEOUT = 0.1
def __init__(self, sock):
self.sock = sock
def poll(self):
return self.sock.poll(self.POLL_TIMEOUT)
def recv(self):
header = self.sock.recv(self.HEADER.size)
data_len = self.HEADER.unpack(header)[0]
terminated_data = self.sock.recv(data_len + len(self.TERMINATOR))
return terminated_data[:-len(self.TERMINATOR)]
def send(self, data):
header = self.HEADER.pack(len(data))
self.sock.send(header + data + self.TERMINATOR)
class ConnectedSOCK(object): class ConnectedSOCK(object):
def __init__(self, low_sock, client_addr): def __init__(self, low_sock, client_addr):
self.client_addr = client_addr self.client_addr = client_addr
self.low_sock = low_sock self.low_sock = low_sock
self.channel = UTCPChannel(self)
def __getattribute__(self, att): def __getattribute__(self, att):
try: try:
return object.__getattribute__(self, att) return object.__getattribute__(self, att)
@ -127,18 +142,23 @@ class ConnectedSOCK(object):
return self.low_sock.recv(size, self.client_addr) return self.low_sock.recv(size, self.client_addr)
@property @property
def closed(self): def closed(self):
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin) return self.low_sock.own_socket._closed or self.client_addr not in self.low_sock.connections
def close(self): def close(self):
if self.client_addr in self.low_sock.connections: if self.client_addr in self.low_sock.connections:
self.low_sock.close(self.client_addr) self.low_sock.close(self.client_addr)
def shutdown(self, *a, **kw): def shutdown(self, *a, **kw):
self.close() self.close()
def poll(self, timeout): def poll(self, timeout):
if self.client_addr in self.packets_received['DATA or FIN']: if not self.closed:
return True conn = self.low_sock.connections[self.client_addr]
else: with conn.recv_lock:
self.incoming_packet_event.wait(timeout) has_data = self.client_addr in self.packets_received['DATA or FIN']
return self.client_addr in self.packets_received['DATA or FIN'] if has_data:
return True
else:
self.incoming_packet_event.wait(timeout)
with conn.recv_lock:
return self.client_addr in self.packets_received['DATA or FIN']
return False return False
class TCP(object): class TCP(object):
@ -153,18 +173,23 @@ class TCP(object):
self.settimeout() self.settimeout()
self.connection_lock = threading.Lock() self.connection_lock = threading.Lock()
self.queue_lock = threading.Lock() self.queue_lock = threading.Lock()
self.channel = None
self.connections = {} self.connections = {}
self.connection_queue = [] self.connection_queue = []
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
def poll(self, timeout): def poll(self, timeout):
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: if len(self.connections):
return True connection = list(self.connections.keys())[0]
else: conn = self.connections[connection]
self.incoming_packet_event.wait(timeout) with conn.recv_lock:
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: has_data = connection in self.packets_received['DATA or FIN']
if has_data:
return True return True
else:
self.incoming_packet_event.wait(timeout)
with conn.recv_lock:
return connection in self.packets_received['DATA or FIN']
return False return False
def get_free_port(self): def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.bind(('', 0)) s.bind(('', 0))
@ -205,22 +230,21 @@ class TCP(object):
data_parts = TCP.data_divider(data) data_parts = TCP.data_divider(data)
for data_part in data_parts: for data_part in data_parts:
data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part)
packet = TCPPacket(conn.seq) packet = Packet()
packet.set_data(data_chunk) packet.set_data(data_chunk)
packet_to_send = pickle.dumps(packet) packet_to_send = pickle.dumps(packet)
answer = self.retransmit(connection, packet_to_send) answer = self.retransmit(connection, packet_to_send, wnat_id=packet.id)
conn.seq_inc(len(data_part))
return len(data) return len(data)
except socket.error as error: except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def retransmit(self, peer_addr, pickled_packet, condition='ACK'): def retransmit(self, peer_addr, pickled_packet, condition='ACK', want_id=None):
data_not_received = True data_not_received = True
retransmit_count = 0 retransmit_count = 0
while data_not_received and retransmit_count < 3: while data_not_received and retransmit_count < 3:
data_not_received = False data_not_received = False
try: try:
self.own_socket.sendto(pickled_packet, peer_addr) self.own_socket.sendto(pickled_packet, peer_addr)
answer = self.find_correct_packet(condition, peer_addr) answer = self.find_correct_packet(condition, peer_addr, want_id=want_id)
if not answer: if not answer:
data_not_received = True data_not_received = True
retransmit_count += 1 retransmit_count += 1
@ -245,11 +269,7 @@ class TCP(object):
except socket.error as error: except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def send_ack(self, connection, ack): def send_ack(self, connection, ack):
conn = self.connections[connection] packet_to_send = pickle.dumps(ack)
ack_packet = TCPPacket(conn.seq_inc())
ack_packet.ack = conn.set_ack(ack)
ack_packet.set_flags(ack=True)
packet_to_send = pickle.dumps(ack_packet)
self.own_socket.sendto(packet_to_send, connection) self.own_socket.sendto(packet_to_send, connection)
def listen_handler(self, max_connections): def listen_handler(self, max_connections):
try: try:
@ -291,23 +311,11 @@ class TCP(object):
with self.queue_lock: with self.queue_lock:
answer, conn = self.connection_queue.pop() answer, conn = self.connection_queue.pop()
self.connections[conn.peer_addr] = conn self.connections[conn.peer_addr] = conn
packet = TCPPacket(conn.seq) syn_ack = SynAck(answer.id)
packet.ack = answer.seq + 1
packet.seq = conn.seq_inc()
packet.set_flags(ack=True, syn=True)
if self.encrypted: if self.encrypted:
packet.set_data(conn.peer_pub.encrypt_raw(conn.pubkey)) syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey))
packet_to_send = pickle.dumps(packet) packet_to_send = pickle.dumps(syn_ack)
#On packet lost retransmit self.retransmit(conn.peer_addr, packet_to_send, 'ACK', want_id=syn_ack.id)
packet_not_sent_correctly = True
while packet_not_sent_correctly or answer is None:
try:
packet_not_sent_correctly = False
self.own_socket.sendto(packet_to_send, conn.peer_addr)
answer = self.find_correct_packet('ACK', conn.peer_addr)
except socket.timeout:
packet_not_sent_correctly = True
conn.ack = answer.seq + 1
return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr
except Exception as error: except Exception as error:
self.close(conn.peer_addr) self.close(conn.peer_addr)
@ -321,13 +329,14 @@ class TCP(object):
self.central_receive() self.central_receive()
conn = Connection(server_address, self.encrypted) conn = Connection(server_address, self.encrypted)
self.connections[server_address] = conn self.connections[server_address] = conn
syn_packet = TCPPacket(conn.seq) syn = Syn()
syn_packet.set_flags(syn=True)
if self.encrypted: if self.encrypted:
syn_packet.set_data(conn.pubkey) syn.set_pub(conn.pubkey)
first_packet_to_send = pickle.dumps(syn_packet) first_packet_to_send = pickle.dumps(syn)
self.own_socket.sendto(first_packet_to_send, server_address) try:
answer = self.find_correct_packet('SYN-ACK', server_address) answer = self.retransmit(server_address, first_packet_to_send, 'SYN-ACK', want_id=syn.id)
except EOFError:
raise EOFError('Remote peer unreachable')
if type(answer) == str: # == 'Connections full': if type(answer) == str: # == 'Connections full':
raise socket.error('Server cant receive any connections right now.') raise socket.error('Server cant receive any connections right now.')
if self.encrypted: if self.encrypted:
@ -336,12 +345,10 @@ class TCP(object):
conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub) conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub)
except: except:
raise socket.error('Decrypt peer public key error') raise socket.error('Decrypt peer public key error')
ack_packet = TCPPacket(conn.seq_inc()) ack = Ack(answer.id)
ack_packet.ack = conn.set_ack(answer.seq + 1) second_packet_to_send = pickle.dumps(ack)
ack_packet.set_flags(ack=True)
second_packet_to_send = pickle.dumps(ack_packet)
self.own_socket.sendto(second_packet_to_send, server_address) self.own_socket.sendto(second_packet_to_send, server_address)
self.channel = UTCPChannel(self)
except socket.error as error: except socket.error as error:
self.own_socket.close() self.own_socket.close()
self.connections = {} self.connections = {}
@ -363,18 +370,16 @@ class TCP(object):
connection = list(self.connections.keys())[0] connection = list(self.connections.keys())[0]
else: else:
raise EOFError('Connection not in connected devices') raise EOFError('Connection not in connected devices')
conn = self.connections[connection] fin = Fin()
fin_packet = TCPPacket(conn.seq_inc()) packet_to_send = pickle.dumps(fin)
fin_packet.set_flags(fin=True)
packet_to_send = pickle.dumps(fin_packet)
self.own_socket.sendto(packet_to_send, connection) self.own_socket.sendto(packet_to_send, connection)
answer = self.retransmit(connection, packet_to_send) answer = self.retransmit(connection, packet_to_send, want_id=fin.id)
conn.ack += 1 answer = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id)
answer = self.find_correct_packet('FIN-ACK', connection) if not answer:
if answer.flag_fin != 1:
raise Exception('The receiver didn\'t send the fin packet') raise Exception('The receiver didn\'t send the fin packet')
else: else:
self.send_ack(connection, conn.ack + 1) ack = Ack(fin.id)
self.send_ack(connection, ack)
self.drop_connection(connection) self.drop_connection(connection)
if len(self.connections) == 0 and self.client: if len(self.connections) == 0 and self.client:
self.own_socket.close() self.own_socket.close()
@ -382,15 +387,14 @@ class TCP(object):
except Exception as error: except Exception as error:
raise EOFError('Something went wrong in the close func! Error is: %s.' % error) raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
def disconnect(self, connection): def disconnect(self, connection, fin_id):
try: try:
conn = self.connections[connection] ack = Ack(fin_id)
self.send_ack(connection, conn.set_ack(conn.ack + 1)) self.send_ack(connection, ack)
finack_packet = TCPPacket(conn.seq_inc()) fin_ack = FinAck(fin_id)
finack_packet.set_flags(fin=True, ack=True) packet_to_send = pickle.dumps(fin_ack)
packet_to_send = pickle.dumps(finack_packet)
try: try:
answer = self.retransmit(connection, packet_to_send) answer = self.retransmit(connection, packet_to_send, want_id=fin_id)
except: except:
pass pass
with self.connection_lock: with self.connection_lock:
@ -409,7 +413,7 @@ class TCP(object):
def checksum(source_bytes): def checksum(source_bytes):
return hashlib.sha1(source_bytes).digest() return hashlib.sha1(source_bytes).digest()
def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH): def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH, want_id=None):
not_found = True not_found = True
tries = 0 tries = 0
while not_found and tries < 2 and self.status: while not_found and tries < 2 and self.status:
@ -419,7 +423,7 @@ class TCP(object):
order = self.packets_received[condition].popitem() # to reverse the tuple received order = self.packets_received[condition].popitem() # to reverse the tuple received
return order[1], order[0] return order[1], order[0]
conn = self.connections[address] conn = self.connections[address]
if condition == 'ACK': if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']:
tries += 1 tries += 1
if condition == 'DATA or FIN': if condition == 'DATA or FIN':
with conn.recv_lock: with conn.recv_lock:
@ -429,6 +433,8 @@ class TCP(object):
del self.packets_received[condition][address] del self.packets_received[condition][address]
else: else:
packet = self.packets_received[condition].pop(address) packet = self.packets_received[condition].pop(address)
if want_id and packet.id != want_id:
raise KeyError
return packet return packet
except KeyError: except KeyError:
not_found = True not_found = True
@ -440,11 +446,11 @@ class TCP(object):
self.new_conn_event.set() self.new_conn_event.set()
self.new_conn_event.clear() self.new_conn_event.clear()
def sort_answers(self, packet, address): def sort_answers(self, packet, address):
if address not in self.connections and packet.packet_type() != 'SYN': if address not in self.connections and not isinstance(packet, Syn):
return return
if packet.packet_type() == 'FIN': if isinstance(packet, Fin):
self.disconnect(address) self.disconnect(address, packet.id)
elif packet.packet_type() == 'DATA': elif isinstance(packet, Packet):
if packet.checksum == TCP.checksum(packet.data): if packet.checksum == TCP.checksum(packet.data):
conn = self.connections[address] conn = self.connections[address]
data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data) data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data)
@ -453,7 +459,8 @@ class TCP(object):
if address not in self.packets_received['DATA or FIN']: if address not in self.packets_received['DATA or FIN']:
self.packets_received['DATA or FIN'][address] = b'' self.packets_received['DATA or FIN'][address] = b''
self.packets_received['DATA or FIN'][address] += data_chunk self.packets_received['DATA or FIN'][address] += data_chunk
self.send_ack(address, packet.seq + len(packet.data)) ack = Ack(packet.id)
self.send_ack(address, ack)
self.blink_incoming_packet_event() self.blink_incoming_packet_event()
elif packet.packet_type() == '': elif packet.packet_type() == '':
#print('redundant packet found', packet) #print('redundant packet found', packet)
@ -468,6 +475,8 @@ class TCP(object):
packet, address = self.own_socket.recvfrom(SENT_SIZE) packet, address = self.own_socket.recvfrom(SENT_SIZE)
packet = restricted_pickle_loads(packet) packet = restricted_pickle_loads(packet)
self.sort_answers(packet, address) self.sort_answers(packet, address)
except pickle.UnpicklingError:
continue
except socket.timeout: except socket.timeout:
continue continue
except socket.error as error: except socket.error as error: