diff --git a/mods/utcp.py b/mods/utcp.py index b13acd1..21053a5 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -8,6 +8,7 @@ import hashlib import simplecrypto from struct import Struct import uuid +import bisect DATA_DIVIDE_LENGTH = 8000 PACKET_HEADER_SIZE = 512 # Pickle service info @@ -17,16 +18,11 @@ LAST_CONNECTION = -1 FIRST = 0 PACKET_END = b'___+++^^^END^^^+++___' -# need for emulate -AF_INET = None -SOCK_STREAM = None - class Connection: SMALLEST_STARTING_SEQ = 0 HIGHEST_STARTING_SEQ = 4294967295 def __init__(self, remote, encrypted=False): self.peer_addr = remote - self.ack = 0 self.seq = Connection.gen_starting_seq_num() self.my_key = None if encrypted: @@ -36,7 +32,6 @@ class Connection: self.recv_lock = threading.Lock() self.send_lock = threading.Lock() self.packet_buffer = { - 'SYN': [], 'ACK': [], 'SYN-ACK': [], 'DATA': [], @@ -52,43 +47,50 @@ class Connection: self.ack = ack return ack -class Ack: +class UTCPPacket: + def __cmp__(self, other): + return (self.seq > other.seq) - (self.seq < other.seq) + +class Ack(UTCPPacket): type = 'ACK' - def __init__(self, id_): + def __init__(self, id_, seq): self.id = id_ -class Fin: + self.seq = seq +class Fin(UTCPPacket): type = 'FIN' - def __init__(self): + def __init__(self, seq): self.id = uuid.uuid4().bytes -class FinAck: + self.seq = seq +class FinAck(UTCPPacket): type = 'FIN-ACK' - def __init__(self, id_): + def __init__(self, id_, seq): self.id = id_ -class Syn: + self.seq = seq +class Syn(UTCPPacket): type = 'SYN' checksum = None - def __init__(self): + def __init__(self, seq): self.id = uuid.uuid4().bytes + self.seq = seq def set_pub(self, pubkey): self.checksum = TCP.checksum(pubkey) self.pubkey = pubkey -class SynAck: +class SynAck(UTCPPacket): type = 'SYN-ACK' checksum = None - def __init__(self, id_): + def __init__(self, id_, seq): self.id = id_ + self.seq = seq def set_pub(self, pubkey): self.checksum = TCP.checksum(pubkey) self.pubkey = pubkey -class Packet: - def __init__(self): +class Packet(UTCPPacket): + def __init__(self, seq): self.id = uuid.uuid4().bytes - self.flag_ack = 0 - self.flag_syn = 0 - self.flag_fin = 0 self.checksum = 0 self.data = b'' + self.seq = seq def __repr__(self): return f'TCPpacket(type={self.packet_type()})' @@ -164,13 +166,13 @@ class ConnectedSOCK(object): if not self.closed: conn = self.low_sock.connections[self.client_addr] with conn.recv_lock: - has_data = self.client_addr in self.packets_received['DATA or FIN'] + has_data = len(conn.packet_buffer['DATA']) if has_data: return True else: self.incoming_packet_event.wait(timeout) with conn.recv_lock: - return self.client_addr in self.packets_received['DATA or FIN'] + return len(conn.packet_buffer['DATA']) return False class TCP(object): @@ -188,19 +190,19 @@ class TCP(object): self.channel = None self.connections = {} self.connection_queue = [] - self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} + self.syn_received = {} def poll(self, timeout): if len(self.connections): connection = list(self.connections.keys())[0] conn = self.connections[connection] with conn.recv_lock: - has_data = connection in self.packets_received['DATA or FIN'] + has_data = len(conn.packet_buffer['DATA']) if has_data: return True else: self.incoming_packet_event.wait(timeout) with conn.recv_lock: - return connection in self.packets_received['DATA or FIN'] + return len(conn.packet_buffer['DATA']) return False def get_free_port(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -274,7 +276,7 @@ class TCP(object): else: raise EOFError('Connection not in connected devices') - data = self.find_correct_packet('DATA or FIN', connection, size) + data = self.find_correct_packet('DATA', connection, size) if not self.status: raise EOFError('Disconnecting') return data @@ -375,8 +377,6 @@ class TCP(object): with self.connection_lock: if len(self.connections): self.connections.pop(connection) - for k in list(self.packets_received.keys()): - self.packets_received[k].pop(connection) def close(self, connection=None): try: if connection not in list(self.connections.keys()): @@ -434,12 +434,12 @@ class TCP(object): try: not_found = False if address[0] == 'Any': - order = self.packets_received[condition].popitem() # to reverse the tuple received + order = self.syn_received.popitem() # to reverse the tuple received return order[1], order[0] conn = self.connections[address] if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: tries += 1 - if condition == 'DATA or FIN': + if condition == 'DATA': with conn.recv_lock: packet = self.packets_received[condition][address][:size] self.packets_received[condition][address] = self.packets_received[condition][address][size:] @@ -474,9 +474,9 @@ class TCP(object): data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data) if data_chunk != PACKET_END: with conn.recv_lock: - if address not in self.packets_received['DATA or FIN']: - self.packets_received['DATA or FIN'][address] = b'' - self.packets_received['DATA or FIN'][address] += data_chunk + if address not in self.packets_received['DATA']: + self.packets_received['DATA'][address] = b'' + self.packets_received['DATA'][address] += data_chunk ack = Ack(packet.id) self.send_ack(address, ack) self.blink_incoming_packet_event()