diff --git a/mods/utcp.py b/mods/utcp.py index cba1c83..647f42f 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -19,12 +19,6 @@ PACKET_END = b'___+++^^^END^^^+++___' AF_INET = None SOCK_STREAM = None -class KeyPair: - my_key = None - peer_pub = None - def __init__(self, sec): - self.my_key = sec - class Connection: SMALLEST_STARTING_SEQ = 0 HIGHEST_STARTING_SEQ = 4294967295 @@ -34,7 +28,7 @@ class Connection: self.seq = Connection.gen_starting_seq_num() self.my_key = None if encrypted: - self.my_key = KeyPair(simplecrypto.RsaKeypair()) + self.my_key = simplecrypto.RsaKeypair() self.pubkey = self.my_key.publickey.serialize() self.peer_pub = None self.recv_lock = threading.Lock() @@ -55,6 +49,7 @@ class TCPPacket(object): self.ack = 0 self.flag_ack = 0 self.flag_syn = 0 + self.flag_fin = 0 self.checksum = 0 self.data = b'' def __repr__(self): @@ -206,18 +201,15 @@ class TCP(object): connection = list(self.connections.keys())[0] else: raise EOFError('Connection not in connected devices') + conn = self.connections[connection] data_parts = TCP.data_divider(data) for data_part in data_parts: - data_chunk = data_part if not self.encrypted else self.peer_keypair[connection].peer_pub.encrypt_raw(data_part) - checksum_of_data = TCP.checksum(data_chunk) - self.connections[connection].checksum = checksum_of_data - self.connections[connection].data = data_chunk - self.connections[connection].set_flags() - packet_to_send = pickle.dumps(self.connections[connection]) - self.connections[connection].checksum = 0 - self.connections[connection].data = b'' + data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) + packet = TCPPacket(conn.seq) + packet.set_data(data_chunk) + packet_to_send = pickle.dumps(packet) answer = self.retransmit(connection, packet_to_send) - self.connections[connection].seq += len(data_part) + conn.seq_inc(len(data_part)) return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) @@ -269,9 +261,8 @@ class TCP(object): conn = Connection(address, self.encrypted) if self.encrypted: try: - conn.peer_pub = answer.data + conn.peer_pub = simplecrypto.RsaPublicKey(answer.data) except: - self.peer_keypair.pop(address) raise socket.error('Init peer public key error') self.connection_queue.append((answer, conn)) self.blink_new_conn_event() @@ -329,7 +320,7 @@ class TCP(object): self.client = True self.central_receive() conn = Connection(server_address, self.encrypted) - self.connections[server_address] = Connection(server_address, self.encrypted) + self.connections[server_address] = conn syn_packet = TCPPacket(conn.seq) syn_packet.set_flags(syn=True) if self.encrypted: @@ -354,7 +345,6 @@ class TCP(object): except socket.error as error: self.own_socket.close() self.connections = {} - self.peer_keypair = {} self.status = 0 raise EOFError('The socket was closed. Error:' + str(error)) def fileno(self): @@ -366,8 +356,6 @@ class TCP(object): with self.connection_lock: if len(self.connections): self.connections.pop(connection) - if len(self.peer_keypair): - self.peer_keypair.pop(connection) def close(self, connection=None): try: if connection not in list(self.connections.keys()): @@ -430,14 +418,15 @@ class TCP(object): if address[0] == 'Any': order = self.packets_received[condition].popitem() # to reverse the tuple received return order[1], order[0] + conn = self.connections[address] if condition == 'ACK': tries += 1 if condition == 'DATA or FIN': - with self.connection_lock: + with conn.recv_lock: packet = self.packets_received[condition][address][:size] self.packets_received[condition][address] = self.packets_received[condition][address][size:] - if not len(self.packets_received[condition][address]): - del self.packets_received[condition][address] + if not len(self.packets_received[condition][address]): + del self.packets_received[condition][address] else: packet = self.packets_received[condition].pop(address) return packet @@ -457,9 +446,10 @@ class TCP(object): self.disconnect(address) elif packet.packet_type() == 'DATA': if packet.checksum == TCP.checksum(packet.data): - data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data) + conn = self.connections[address] + data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data) if data_chunk != PACKET_END: - with self.connection_lock: + with conn.recv_lock: if address not in self.packets_received['DATA or FIN']: self.packets_received['DATA or FIN'][address] = b'' self.packets_received['DATA or FIN'][address] += data_chunk