From 11b7ec1ee27fc6d5a470ea5a00b1e60299de4032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD=20=D0=91=D0=BE=D1=80=D0=BE?= =?UTF-8?q?=D0=B4=D0=B8=D0=BD?= Date: Mon, 8 Apr 2019 22:15:04 +0300 Subject: [PATCH] working version --- mods/utcp.py | 182 +++++++++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 79 deletions(-) diff --git a/mods/utcp.py b/mods/utcp.py index 21053a5..cad67b3 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -24,6 +24,7 @@ class Connection: def __init__(self, remote, encrypted=False): self.peer_addr = remote self.seq = Connection.gen_starting_seq_num() + self.recv_seq = 0 self.my_key = None if encrypted: self.my_key = simplecrypto.RsaKeypair() @@ -50,49 +51,56 @@ class Connection: class UTCPPacket: def __cmp__(self, other): return (self.seq > other.seq) - (self.seq < other.seq) + def __lt__(self, other): + return self.seq < other.seq + def __gt__(self, other): + return self.seq > other.seq + def __eq__(self, other): + return self.seq == other.seq class Ack(UTCPPacket): type = 'ACK' - def __init__(self, id_, seq): + def __init__(self, id_): self.id = id_ - self.seq = seq + self.seq = 0 class Fin(UTCPPacket): type = 'FIN' - def __init__(self, seq): + def __init__(self): self.id = uuid.uuid4().bytes - self.seq = seq + self.seq = 0 class FinAck(UTCPPacket): type = 'FIN-ACK' - def __init__(self, id_, seq): + def __init__(self, id_): self.id = id_ - self.seq = seq + self.seq = 0 class Syn(UTCPPacket): type = 'SYN' checksum = None - def __init__(self, seq): + def __init__(self): self.id = uuid.uuid4().bytes - self.seq = seq + self.seq = 0 def set_pub(self, pubkey): self.checksum = TCP.checksum(pubkey) self.pubkey = pubkey class SynAck(UTCPPacket): type = 'SYN-ACK' checksum = None - def __init__(self, id_, seq): + def __init__(self, id_): self.id = id_ - self.seq = seq + self.seq = 0 def set_pub(self, pubkey): self.checksum = TCP.checksum(pubkey) self.pubkey = pubkey class Packet(UTCPPacket): - def __init__(self, seq): + type = 'DATA' + def __init__(self): self.id = uuid.uuid4().bytes self.checksum = 0 self.data = b'' - self.seq = seq + self.seq = 0 def __repr__(self): - return f'TCPpacket(type={self.packet_type()})' + return f'Packet(seq={self.seq})' def __str__(self): return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data) @@ -246,28 +254,10 @@ class TCP(object): data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) packet = Packet() packet.set_data(data_chunk) - packet_to_send = pickle.dumps(packet) - answer = self.retransmit(connection, packet_to_send, want_id=packet.id) + answer = self.__send_packet(connection, packet, retransmit=True) return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) - def retransmit(self, peer_addr, pickled_packet, condition='ACK', want_id=None): - data_not_received = True - retransmit_count = 0 - while data_not_received and retransmit_count < 3: - data_not_received = False - try: - self.own_socket.sendto(pickled_packet, peer_addr) - answer = self.find_correct_packet(condition, peer_addr, want_id=want_id) - if not answer: - data_not_received = True - retransmit_count += 1 - except socket.timeout: - data_not_received = True - if not answer: - self.drop_connection(peer_addr) - raise EOFError('Connection lost') - return answer def recv(self, size, connection=None): try: if connection not in list(self.connections.keys()): @@ -275,16 +265,35 @@ class TCP(object): connection = list(self.connections.keys())[0] else: raise EOFError('Connection not in connected devices') - data = self.find_correct_packet('DATA', connection, size) if not self.status: raise EOFError('Disconnecting') return data except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) - def send_ack(self, connection, ack): - packet_to_send = pickle.dumps(ack) - self.own_socket.sendto(packet_to_send, connection) + def __send_packet(self, peer_addr, packet, retransmit=False, wait_cond='ACK'): + conn = self.connections[peer_addr] + packet.seq = conn.seq_inc() + packet_to_send = pickle.dumps(packet) + if not retransmit: + self.own_socket.sendto(packet_to_send, peer_addr) + else: + data_not_received = True + retransmit_count = 0 + while data_not_received and retransmit_count < 3: + data_not_received = False + try: + self.own_socket.sendto(packet_to_send, peer_addr) + answer = self.find_correct_packet(wait_cond, peer_addr, want_id=packet.id) + if not answer: + data_not_received = True + retransmit_count += 1 + except socket.timeout: + data_not_received = True + if not answer: + self.drop_connection(peer_addr) + raise EOFError('Connection lost') + return answer def listen_handler(self, max_connections): try: while True and self.status: @@ -301,7 +310,9 @@ class TCP(object): self.connection_queue.append((answer, conn)) self.blink_new_conn_event() else: - self.own_socket.sendto('Connections full', address) + if answer.id in map(lambda x: x[0].id, self.connection_queue): + continue + self.own_socket.sendto(b'Connections full', address) except KeyError: continue except socket.error as error: @@ -328,8 +339,7 @@ class TCP(object): syn_ack = SynAck(answer.id) if self.encrypted: syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey)) - packet_to_send = pickle.dumps(syn_ack) - self.retransmit(conn.peer_addr, packet_to_send, 'ACK', want_id=syn_ack.id) + answer = self.__send_packet(conn.peer_addr, syn_ack, retransmit=True) return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr except Exception as error: self.close(conn.peer_addr) @@ -346,9 +356,8 @@ class TCP(object): syn = Syn() if self.encrypted: syn.set_pub(conn.pubkey) - first_packet_to_send = pickle.dumps(syn) try: - answer = self.retransmit(server_address, first_packet_to_send, 'SYN-ACK', want_id=syn.id) + answer = self.__send_packet(server_address, syn, retransmit=True, wait_cond='SYN-ACK') except EOFError: raise EOFError('Remote peer unreachable') if type(answer) == str: # == 'Connections full': @@ -360,8 +369,7 @@ class TCP(object): except: raise socket.error('Decrypt peer public key error') ack = Ack(answer.id) - second_packet_to_send = pickle.dumps(ack) - self.own_socket.sendto(second_packet_to_send, server_address) + self.__send_packet(server_address, ack) self.channel = UTCPChannel(self) except socket.error as error: self.own_socket.close() @@ -385,15 +393,13 @@ class TCP(object): else: raise EOFError('Connection not in connected devices') fin = Fin() - packet_to_send = pickle.dumps(fin) - self.own_socket.sendto(packet_to_send, connection) - answer = self.retransmit(connection, packet_to_send, want_id=fin.id) + answer = self.__send_packet(connection, fin, retransmit=True) answer = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id) if not answer: raise Exception('The receiver didn\'t send the fin packet') else: ack = Ack(fin.id) - self.send_ack(connection, ack) + self.__send_packet(connection, ack) self.drop_connection(connection) if len(self.connections) == 0 and self.client: self.own_socket.close() @@ -404,11 +410,10 @@ class TCP(object): def disconnect(self, connection, fin_id): try: ack = Ack(fin_id) - self.send_ack(connection, ack) + self.__send_packet(connection, ack) fin_ack = FinAck(fin_id) - packet_to_send = pickle.dumps(fin_ack) try: - answer = self.retransmit(connection, packet_to_send, want_id=fin_id) + answer = self.__send_packet(connection, fin_ack, retransmit=True) except: pass with self.connection_lock: @@ -420,7 +425,6 @@ class TCP(object): def data_divider(data): '''Divides the data into a list where each element's length is 1024''' data = [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)] - data.append(PACKET_END) return data @staticmethod @@ -440,14 +444,30 @@ class TCP(object): if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: tries += 1 if condition == 'DATA': - with conn.recv_lock: - packet = self.packets_received[condition][address][:size] - self.packets_received[condition][address] = self.packets_received[condition][address][size:] - if not len(self.packets_received[condition][address]): - del self.packets_received[condition][address] + if len(conn.packet_buffer[condition]): + data = b'' + while size and conn.packet_buffer[condition]: + with conn.recv_lock: + packet = conn.packet_buffer[condition][0] + chunk = packet.data[:size] + chunk_len = len(chunk) + data += chunk + packet.data = packet.data[size:] + size -= chunk_len + if not len(packet.data): + try: + conn.packet_buffer[condition].pop(0) + except IndexError: + size = 0 + return data + else: + raise KeyError else: with conn.recv_lock: - packet = self.packets_received[condition].pop(address) + if len(conn.packet_buffer[condition]): + packet = conn.packet_buffer[condition].pop() + else: + raise KeyError if want_id and packet.id != want_id: raise KeyError return packet @@ -463,31 +483,35 @@ class TCP(object): def sort_answers(self, packet, address): if address not in self.connections and not isinstance(packet, Syn): return - if isinstance(packet, Syn): - with self.queue_lock: - pass + if not isinstance(packet, Syn): + conn = self.connections[address] + if conn.recv_seq == packet.seq: + ack = Ack(packet.id) + self.__send_packet(address, ack) + return + elif conn.recv_seq > packet.seq: + return + else: + conn.recv_seq = packet.seq + if isinstance(packet, Packet): + if packet.checksum == TCP.checksum(packet.data): + if self.encrypted: + packet.data = conn.my_key.decrypt_raw(packet.data) + else: + return if isinstance(packet, Fin): self.disconnect(address, packet.id) - elif isinstance(packet, Packet): - if packet.checksum == TCP.checksum(packet.data): - conn = self.connections[address] - data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data) - if data_chunk != PACKET_END: - with conn.recv_lock: - if address not in self.packets_received['DATA']: - self.packets_received['DATA'][address] = b'' - self.packets_received['DATA'][address] += data_chunk - ack = Ack(packet.id) - self.send_ack(address, ack) - self.blink_incoming_packet_event() + elif isinstance(packet, Syn): + if packet.id not in map(lambda x: x.id, self.syn_received.values()): + self.syn_received[address] = packet else: - if address in self.packets_received[packet.type]: - conn = self.connections[address] - with conn.recv_lock: - if packet.id == self.packets_received[packet.type][address].id: - return - self.packets_received[packet.type][address] = packet - self.blink_incoming_packet_event() + with conn.recv_lock: + if packet.id not in map(lambda x: x.id, conn.packet_buffer[packet.type]): + bisect.insort(conn.packet_buffer[packet.type], packet) + if isinstance(packet, Packet): + ack = Ack(packet.id) + self.__send_packet(address, ack) + self.blink_incoming_packet_event() def central_receive_handler(self): while True and self.status: