diff --git a/mods/utcp.py b/mods/utcp.py index 647f42f..65f39f8 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -6,6 +6,8 @@ import threading import io import hashlib import simplecrypto +from struct import Struct +import uuid DATA_DIVIDE_LENGTH = 8000 PACKET_HEADER_SIZE = 512 # Pickle service info @@ -43,10 +45,33 @@ class Connection: self.ack = ack return ack -class TCPPacket(object): - def __init__(self, seq): - self.seq = seq - self.ack = 0 +class Ack: + def __init__(self, id_): + self.id = id_ +class Fin: + def __init__(self): + self.id = uuid.uuid4().bytes +class FinAck: + def __init__(self, id_): + self.id = id_ +class Syn: + checksum = None + def __init__(self): + self.id = uuid.uuid4().bytes + def set_pub(self, pubkey): + self.checksum = TCP.checksum(pubkey) + self.pubkey = pubkey +class SynAck: + checksum = None + def __init__(self, id_): + self.id = id_ + def set_pub(self, pubkey): + self.checksum = TCP.checksum(pubkey) + self.pubkey = pubkey + +class Packet: + def __init__(self): + self.id = uuid.uuid4().bytes self.flag_ack = 0 self.flag_syn = 0 self.flag_fin = 0 @@ -56,41 +81,8 @@ class TCPPacket(object): return f'TCPpacket(type={self.packet_type()})' def __str__(self): - return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \ - % (self.seq, self.ack, self.flag_ack, self.flag_syn, self.flag_fin, self.packet_type(), self.data) + return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data) - def __cmp__(self, other): - return (self.seq > other.seq) - (self.seq < other.seq) - - def packet_type(self): - packet_type = '' - if self.flag_syn == 1 and self.flag_ack == 1: - packet_type = 'SYN-ACK' - elif self.flag_ack == 1 and self.flag_fin == 1: - packet_type = 'FIN-ACK' - elif self.flag_syn == 1: - packet_type = 'SYN' - elif self.flag_ack == 1: - packet_type = 'ACK' - elif self.flag_fin == 1: - packet_type = 'FIN' - elif self.data != b'': - packet_type = 'DATA' - return packet_type - - def set_flags(self, ack=False, syn=False, fin=False): - if ack: - self.flag_ack = 1 - else: - self.flag_ack = 0 - if syn: - self.flag_syn = 1 - else: - self.flag_syn = 0 - if fin: - self.flag_fin = 1 - else: - self.flag_fin = 0 def set_data(self, data): self.checksum = TCP.checksum(data) self.data = data @@ -99,17 +91,40 @@ class TCPPacket(object): class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): - if name == 'TCPPacket': - return TCPPacket + if module != 'builtins': + if name == 'Packet': return Packet + if name == 'Ack': return Ack + if name == 'Fin': return Fin + if name == 'FinAck': return FinAck + if name == 'Syn': return Syn + if name == 'SynAck': return SynAck raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name)) def restricted_pickle_loads(s): return RestrictedUnpickler(io.BytesIO(s)).load() +class UTCPChannel: + HEADER = Struct('!I') + TERMINATOR = b'\n' + POLL_TIMEOUT = 0.1 + def __init__(self, sock): + self.sock = sock + def poll(self): + return self.sock.poll(self.POLL_TIMEOUT) + def recv(self): + header = self.sock.recv(self.HEADER.size) + data_len = self.HEADER.unpack(header)[0] + terminated_data = self.sock.recv(data_len + len(self.TERMINATOR)) + return terminated_data[:-len(self.TERMINATOR)] + def send(self, data): + header = self.HEADER.pack(len(data)) + self.sock.send(header + data + self.TERMINATOR) + class ConnectedSOCK(object): def __init__(self, low_sock, client_addr): self.client_addr = client_addr self.low_sock = low_sock + self.channel = UTCPChannel(self) def __getattribute__(self, att): try: return object.__getattribute__(self, att) @@ -127,18 +142,23 @@ class ConnectedSOCK(object): return self.low_sock.recv(size, self.client_addr) @property def closed(self): - return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin) + return self.low_sock.own_socket._closed or self.client_addr not in self.low_sock.connections def close(self): if self.client_addr in self.low_sock.connections: self.low_sock.close(self.client_addr) def shutdown(self, *a, **kw): self.close() def poll(self, timeout): - if self.client_addr in self.packets_received['DATA or FIN']: - return True - else: - self.incoming_packet_event.wait(timeout) - return self.client_addr in self.packets_received['DATA or FIN'] + if not self.closed: + conn = self.low_sock.connections[self.client_addr] + with conn.recv_lock: + has_data = self.client_addr in self.packets_received['DATA or FIN'] + if has_data: + return True + else: + self.incoming_packet_event.wait(timeout) + with conn.recv_lock: + return self.client_addr in self.packets_received['DATA or FIN'] return False class TCP(object): @@ -153,18 +173,23 @@ class TCP(object): self.settimeout() self.connection_lock = threading.Lock() self.queue_lock = threading.Lock() + self.channel = None self.connections = {} self.connection_queue = [] self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} def poll(self, timeout): - if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: - return True - else: - self.incoming_packet_event.wait(timeout) - if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: + if len(self.connections): + connection = list(self.connections.keys())[0] + conn = self.connections[connection] + with conn.recv_lock: + has_data = connection in self.packets_received['DATA or FIN'] + if has_data: return True + else: + self.incoming_packet_event.wait(timeout) + with conn.recv_lock: + return connection in self.packets_received['DATA or FIN'] return False - def get_free_port(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.bind(('', 0)) @@ -205,22 +230,21 @@ class TCP(object): data_parts = TCP.data_divider(data) for data_part in data_parts: data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) - packet = TCPPacket(conn.seq) + packet = Packet() packet.set_data(data_chunk) packet_to_send = pickle.dumps(packet) - answer = self.retransmit(connection, packet_to_send) - conn.seq_inc(len(data_part)) + answer = self.retransmit(connection, packet_to_send, wnat_id=packet.id) return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) - def retransmit(self, peer_addr, pickled_packet, condition='ACK'): + def retransmit(self, peer_addr, pickled_packet, condition='ACK', want_id=None): data_not_received = True retransmit_count = 0 while data_not_received and retransmit_count < 3: data_not_received = False try: self.own_socket.sendto(pickled_packet, peer_addr) - answer = self.find_correct_packet(condition, peer_addr) + answer = self.find_correct_packet(condition, peer_addr, want_id=want_id) if not answer: data_not_received = True retransmit_count += 1 @@ -245,11 +269,7 @@ class TCP(object): except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) def send_ack(self, connection, ack): - conn = self.connections[connection] - ack_packet = TCPPacket(conn.seq_inc()) - ack_packet.ack = conn.set_ack(ack) - ack_packet.set_flags(ack=True) - packet_to_send = pickle.dumps(ack_packet) + packet_to_send = pickle.dumps(ack) self.own_socket.sendto(packet_to_send, connection) def listen_handler(self, max_connections): try: @@ -291,23 +311,11 @@ class TCP(object): with self.queue_lock: answer, conn = self.connection_queue.pop() self.connections[conn.peer_addr] = conn - packet = TCPPacket(conn.seq) - packet.ack = answer.seq + 1 - packet.seq = conn.seq_inc() - packet.set_flags(ack=True, syn=True) + syn_ack = SynAck(answer.id) if self.encrypted: - packet.set_data(conn.peer_pub.encrypt_raw(conn.pubkey)) - packet_to_send = pickle.dumps(packet) - #On packet lost retransmit - packet_not_sent_correctly = True - while packet_not_sent_correctly or answer is None: - try: - packet_not_sent_correctly = False - self.own_socket.sendto(packet_to_send, conn.peer_addr) - answer = self.find_correct_packet('ACK', conn.peer_addr) - except socket.timeout: - packet_not_sent_correctly = True - conn.ack = answer.seq + 1 + syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey)) + packet_to_send = pickle.dumps(syn_ack) + self.retransmit(conn.peer_addr, packet_to_send, 'ACK', want_id=syn_ack.id) return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr except Exception as error: self.close(conn.peer_addr) @@ -321,13 +329,14 @@ class TCP(object): self.central_receive() conn = Connection(server_address, self.encrypted) self.connections[server_address] = conn - syn_packet = TCPPacket(conn.seq) - syn_packet.set_flags(syn=True) + syn = Syn() if self.encrypted: - syn_packet.set_data(conn.pubkey) - first_packet_to_send = pickle.dumps(syn_packet) - self.own_socket.sendto(first_packet_to_send, server_address) - answer = self.find_correct_packet('SYN-ACK', server_address) + syn.set_pub(conn.pubkey) + first_packet_to_send = pickle.dumps(syn) + try: + answer = self.retransmit(server_address, first_packet_to_send, 'SYN-ACK', want_id=syn.id) + except EOFError: + raise EOFError('Remote peer unreachable') if type(answer) == str: # == 'Connections full': raise socket.error('Server cant receive any connections right now.') if self.encrypted: @@ -336,12 +345,10 @@ class TCP(object): conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub) except: raise socket.error('Decrypt peer public key error') - ack_packet = TCPPacket(conn.seq_inc()) - ack_packet.ack = conn.set_ack(answer.seq + 1) - ack_packet.set_flags(ack=True) - second_packet_to_send = pickle.dumps(ack_packet) + ack = Ack(answer.id) + second_packet_to_send = pickle.dumps(ack) self.own_socket.sendto(second_packet_to_send, server_address) - + self.channel = UTCPChannel(self) except socket.error as error: self.own_socket.close() self.connections = {} @@ -363,18 +370,16 @@ class TCP(object): connection = list(self.connections.keys())[0] else: raise EOFError('Connection not in connected devices') - conn = self.connections[connection] - fin_packet = TCPPacket(conn.seq_inc()) - fin_packet.set_flags(fin=True) - packet_to_send = pickle.dumps(fin_packet) + fin = Fin() + packet_to_send = pickle.dumps(fin) self.own_socket.sendto(packet_to_send, connection) - answer = self.retransmit(connection, packet_to_send) - conn.ack += 1 - answer = self.find_correct_packet('FIN-ACK', connection) - if answer.flag_fin != 1: + answer = self.retransmit(connection, packet_to_send, want_id=fin.id) + answer = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id) + if not answer: raise Exception('The receiver didn\'t send the fin packet') else: - self.send_ack(connection, conn.ack + 1) + ack = Ack(fin.id) + self.send_ack(connection, ack) self.drop_connection(connection) if len(self.connections) == 0 and self.client: self.own_socket.close() @@ -382,15 +387,14 @@ class TCP(object): except Exception as error: raise EOFError('Something went wrong in the close func! Error is: %s.' % error) - def disconnect(self, connection): + def disconnect(self, connection, fin_id): try: - conn = self.connections[connection] - self.send_ack(connection, conn.set_ack(conn.ack + 1)) - finack_packet = TCPPacket(conn.seq_inc()) - finack_packet.set_flags(fin=True, ack=True) - packet_to_send = pickle.dumps(finack_packet) + ack = Ack(fin_id) + self.send_ack(connection, ack) + fin_ack = FinAck(fin_id) + packet_to_send = pickle.dumps(fin_ack) try: - answer = self.retransmit(connection, packet_to_send) + answer = self.retransmit(connection, packet_to_send, want_id=fin_id) except: pass with self.connection_lock: @@ -409,7 +413,7 @@ class TCP(object): def checksum(source_bytes): return hashlib.sha1(source_bytes).digest() - def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH): + def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH, want_id=None): not_found = True tries = 0 while not_found and tries < 2 and self.status: @@ -419,7 +423,7 @@ class TCP(object): order = self.packets_received[condition].popitem() # to reverse the tuple received return order[1], order[0] conn = self.connections[address] - if condition == 'ACK': + if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: tries += 1 if condition == 'DATA or FIN': with conn.recv_lock: @@ -429,6 +433,8 @@ class TCP(object): del self.packets_received[condition][address] else: packet = self.packets_received[condition].pop(address) + if want_id and packet.id != want_id: + raise KeyError return packet except KeyError: not_found = True @@ -440,11 +446,11 @@ class TCP(object): self.new_conn_event.set() self.new_conn_event.clear() def sort_answers(self, packet, address): - if address not in self.connections and packet.packet_type() != 'SYN': + if address not in self.connections and not isinstance(packet, Syn): return - if packet.packet_type() == 'FIN': - self.disconnect(address) - elif packet.packet_type() == 'DATA': + if isinstance(packet, Fin): + self.disconnect(address, packet.id) + elif isinstance(packet, Packet): if packet.checksum == TCP.checksum(packet.data): conn = self.connections[address] data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data) @@ -453,7 +459,8 @@ class TCP(object): if address not in self.packets_received['DATA or FIN']: self.packets_received['DATA or FIN'][address] = b'' self.packets_received['DATA or FIN'][address] += data_chunk - self.send_ack(address, packet.seq + len(packet.data)) + ack = Ack(packet.id) + self.send_ack(address, ack) self.blink_incoming_packet_event() elif packet.packet_type() == '': #print('redundant packet found', packet) @@ -468,6 +475,8 @@ class TCP(object): packet, address = self.own_socket.recvfrom(SENT_SIZE) packet = restricted_pickle_loads(packet) self.sort_answers(packet, address) + except pickle.UnpicklingError: + continue except socket.timeout: continue except socket.error as error: