From 586e69552bb8135f3bff796c83e09342efa60516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD=20=D0=91=D0=BE=D1=80=D0=BE?= =?UTF-8?q?=D0=B4=D0=B8=D0=BD?= Date: Mon, 8 Apr 2019 09:22:52 +0300 Subject: [PATCH] . --- mods/utcp.py | 82 ++++++++++++++++++++++++---------------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/mods/utcp.py b/mods/utcp.py index 6d63d30..f3f4afb 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -30,8 +30,15 @@ class Connection: HIGHEST_STARTING_SEQ = 4294967295 def __init__(self, remote, encrypted=False): self.peer_addr = remote + self.ack = 0 self.seq = Connection.gen_starting_seq_num() - self.my_key + self.my_key = None + if encrypted: + self.my_key = KeyPair(simplecrypto.RsaKeypair()) + self.pubkey = self.my_key.publickey.serialize() + self.peer_pub = None + self.recv_lock = threading.Lock() + self.send_lock = threading.Lock() @staticmethod def gen_starting_seq_num(): return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ) @@ -86,6 +93,9 @@ class TCPPacket(object): self.flag_fin = 1 else: self.flag_fin = 0 + def set_data(self, data): + self.checksum = TCP.checksum(data) + self.data = data @@ -145,7 +155,6 @@ class TCP(object): self.settimeout() self.connection_lock = threading.Lock() self.queue_lock = threading.Lock() - self.peer_keypair = {} self.connections = {} self.connection_queue = [] self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} @@ -254,15 +263,14 @@ class TCP(object): answer, address = self.find_correct_packet('SYN') with self.queue_lock: if len(self.connection_queue) < max_connections: + conn = Connection(address, self.encrypted) if self.encrypted: try: - peer_pub = answer.data - self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec) - self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub) + conn.peer_pub = answer.data except: self.peer_keypair.pop(address) raise socket.error('Init peer public key error') - self.connection_queue.append((answer, address)) + self.connection_queue.append((answer, conn)) self.blink_new_conn_event() else: self.own_socket.sendto('Connections full', address) @@ -287,34 +295,28 @@ class TCP(object): self.new_conn_event.wait(0.1) if self.connection_queue: with self.queue_lock: - answer, address = self.connection_queue.pop() - - self.connections[address] = TCPPacket() - self.connections[address].ack = answer.seq + 1 - self.connections[address].seq += 1 - self.connections[address].set_flags(ack=True, syn=True) + answer, conn = self.connection_queue.pop() + self.connections[conn.peer_addr] = conn + packet = TCPPacket(conn.seq) + packet.ack = answer.seq + 1 + packet.seq = conn.seq_inc() + packet.set_flags(ack=True, syn=True) if self.encrypted: - pubkey = self.peer_keypair[address].my_key.publickey.serialize() - self.connections[address].data = self.peer_keypair[address].peer_pub.encrypt_raw(pubkey) - self.connections[address].checksum = TCP.checksum(self.connections[address].data) - packet_to_send = pickle.dumps(self.connections[address]) - if self.encrypted: - self.connections[address].data = b'' - self.connections[address].checksum = 0 - #lock address, connections dictionary? + packet.set_data(conn.peer_pub.encrypt_raw(conn.pubkey)) + packet_to_send = pickle.dumps(packet) + #On packet lost retransmit packet_not_sent_correctly = True while packet_not_sent_correctly or answer is None: try: packet_not_sent_correctly = False - self.own_socket.sendto(packet_to_send, address) - answer = self.find_correct_packet('ACK', address) + self.own_socket.sendto(packet_to_send, conn.peer_addr) + answer = self.find_correct_packet('ACK', conn.peer_addr) except socket.timeout: packet_not_sent_correctly = True - self.connections[address].set_flags() - self.connections[address].ack = answer.seq + 1 - return ConnectedSOCK(self, address), address + conn.ack = answer.seq + 1 + return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr except Exception as error: - self.close(address) + self.close(conn.peer_addr) raise EOFError('Something went wrong in accept func: ' + str(error)) def connect(self, server_address=('127.0.0.1', 10000)): @@ -323,32 +325,24 @@ class TCP(object): self.status = 1 self.client = True self.central_receive() - self.connections[server_address] = TCPPacket() - self.connections[server_address].set_flags(syn=True) + conn = Connection(server_address, self.encrypted) + self.connections[server_address] = Connection(server_address, self.encrypted) + syn_packet = TCPPacket(conn.seq) + syn_packet.set_flags(syn=True) if self.encrypted: - self.peer_keypair[server_address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: but here it creates the key quickly - pubkey = self.peer_keypair[server_address].my_key.publickey.serialize() - pub_checksum = TCP.checksum(pubkey) - self.connections[server_address].checksum = pub_checksum - self.connections[server_address].data = pubkey - first_packet_to_send = pickle.dumps(self.connections[server_address]) - self.connections[server_address].data = b'' - self.connections[server_address].checksum = 0 - - self.own_socket.sendto(first_packet_to_send, list(self.connections.keys())[FIRST]) - self.connections[server_address].set_flags() + syn_packet.set_data(conn.pubkey) + first_packet_to_send = pickle.dumps(syn_packet) + self.own_socket.sendto(first_packet_to_send, server_address) answer = self.find_correct_packet('SYN-ACK', server_address) if type(answer) == str: # == 'Connections full': raise socket.error('Server cant receive any connections right now.') if self.encrypted: try: - peer_pub = self.peer_keypair[server_address].my_key.decrypt_raw(answer.data) - self.peer_keypair[server_address].peer_pub = simplecrypto.RsaPublicKey(peer_pub) + peer_pub = conn.my_key.decrypt_raw(answer.data) + conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub) except: raise socket.error('Decrypt peer public key error') - if not peer_pub or answer.checksum != TCP.checksum(answer.data): - raise socket.error('Invalid peer public key') - + ack_packet = self.connections[server_address].ack = answer.seq + 1 self.connections[server_address].seq += 1 self.connections[server_address].set_flags(ack=True)