diff --git a/mods/rpyc_utcp.py b/mods/rpyc_utcp.py index e68f130..daf9725 100644 --- a/mods/rpyc_utcp.py +++ b/mods/rpyc_utcp.py @@ -15,7 +15,7 @@ class UTCPSocketStream(SocketStream): MAX_IO_CHUNK = utcp.DATA_LENGTH @classmethod def utcp_connect(cls, host, port, *a, **kw): - sock = utcp.TCP(encrypted=True) + sock = utcp.UTCP(encrypted=True) sock.connect((host, port)) return cls(sock) def poll(self, timeout): @@ -54,7 +54,7 @@ class UTCPThreadedServer(ThreadedServer): self.listener.close() self.listener = None ########## - self.listener = utcp.TCP(encrypted=True) + self.listener = utcp.UTCP(encrypted=True) self.listener.bind((hostname, port)) sockname = self.listener.getsockname() self.host, self.port = sockname[0], sockname[1] diff --git a/mods/utcp.py b/mods/utcp.py index 7d6eea6..6b36adf 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -1,4 +1,4 @@ -# based on https://github.com/ethay012/TCP-over-UDP +# based on https://github.com/ethay012/UTCP-over-UDP import random import socket import pickle @@ -14,9 +14,6 @@ DATA_DIVIDE_LENGTH = 8000 PACKET_HEADER_SIZE = 512 # Pickle service info DATA_LENGTH = DATA_DIVIDE_LENGTH SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272 # Encrypted data always 272 bytes bigger -LAST_CONNECTION = -1 -FIRST = 0 -PACKET_END = b'___+++^^^END^^^+++___' class Connection: SMALLEST_STARTING_SEQ = 0 @@ -24,7 +21,7 @@ class Connection: def __init__(self, remote, encrypted=False): self.peer_addr = remote self.seq = Connection.gen_starting_seq_num() - self.recv_seq = 0 + self.recv_seq = -1 self.my_key = None if encrypted: self.my_key = simplecrypto.RsaKeypair() @@ -80,7 +77,7 @@ class Syn(UTCPPacket): self.id = uuid.uuid4().bytes self.seq = 0 def set_pub(self, pubkey): - self.checksum = TCP.checksum(pubkey) + self.checksum = UTCP.checksum(pubkey) self.pubkey = pubkey class SynAck(UTCPPacket): type = 'SYN-ACK' @@ -89,7 +86,7 @@ class SynAck(UTCPPacket): self.id = id_ self.seq = 0 def set_pub(self, pubkey): - self.checksum = TCP.checksum(pubkey) + self.checksum = UTCP.checksum(pubkey) self.pubkey = pubkey class Packet(UTCPPacket): @@ -106,7 +103,7 @@ class Packet(UTCPPacket): return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data) def set_data(self, data): - self.checksum = TCP.checksum(data) + self.checksum = UTCP.checksum(data) self.data = data @@ -127,7 +124,6 @@ def restricted_pickle_loads(s): class UTCPChannel: HEADER = Struct('!I') - TERMINATOR = b'\n' POLL_TIMEOUT = 0.1 def __init__(self, sock): self.sock = sock @@ -136,11 +132,10 @@ class UTCPChannel: def recv(self): header = self.sock.recv(self.HEADER.size) data_len = self.HEADER.unpack(header)[0] - terminated_data = self.sock.recv(data_len + len(self.TERMINATOR)) - return terminated_data[:-len(self.TERMINATOR)] + return self.sock.recv(data_len) def send(self, data): header = self.HEADER.pack(len(data)) - self.sock.send(header + data + self.TERMINATOR) + self.sock.send(header + data) class ConnectedSOCK(object): def __init__(self, low_sock, client_addr): @@ -183,15 +178,16 @@ class ConnectedSOCK(object): return len(conn.packet_buffer['DATA']) return False -class TCP(object): +class UTCP(object): host = None port = None client = False - def __init__(self, af_type=None, sock_type=None, encrypted=False): + def __init__(self,encrypted=False, **kw): self.encrypted = encrypted self.incoming_packet_event = threading.Event() self.new_conn_event = threading.Event() self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.own_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.settimeout() self.connection_lock = threading.Lock() self.queue_lock = threading.Lock() @@ -204,13 +200,13 @@ class TCP(object): connection = list(self.connections.keys())[0] conn = self.connections[connection] with conn.recv_lock: - has_data = len(conn.packet_buffer['DATA']) + has_data = bool(len(conn.packet_buffer['DATA'])) if has_data: return True else: self.incoming_packet_event.wait(timeout) with conn.recv_lock: - return len(conn.packet_buffer['DATA']) + return bool(len(conn.packet_buffer['DATA'])) return False def get_free_port(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -225,7 +221,7 @@ class TCP(object): def setblocking(self, mode): self.own_socket.setblocking(mode) def __repr__(self): - return 'TCP()' + return 'UTCP()' def __str__(self): return 'Connections: %s' \ @@ -251,12 +247,12 @@ class TCP(object): else: raise EOFError('Connection not in connected devices') conn = self.connections[connection] - data_parts = TCP.data_divider(data) + data_parts = UTCP.data_divider(data) for data_part in data_parts: data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) packet = Packet() packet.set_data(data_chunk) - answer = self.__send_packet(connection, packet, retransmit=True) + self.__send_packet(connection, packet, retransmit=True) return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) @@ -341,6 +337,7 @@ class TCP(object): if self.connection_queue: with self.queue_lock: answer, conn = self.connection_queue.pop() + conn.recv_seq = answer.seq self.connections[conn.peer_addr] = conn syn_ack = SynAck(answer.id) if self.encrypted: @@ -410,7 +407,7 @@ class TCP(object): raise Exception('The receiver didn\'t send the fin packet') else: self.drop_connection(connection) - if len(self.connections) == 0 and self.client: + if not len(self.connections) and self.client: self.stop() except Exception as error: raise EOFError('Something went wrong in the close func! Error is: %s.' % error) @@ -421,7 +418,7 @@ class TCP(object): self.__send_packet(connection, ack) fin_ack = FinAck(fin_id) try: - answer = self.__send_packet(connection, fin_ack) + self.__send_packet(connection, fin_ack) except: pass with self.connection_lock: @@ -454,7 +451,9 @@ class TCP(object): if condition == 'DATA': if len(conn.packet_buffer[condition]): data = b'' - while size and conn.packet_buffer[condition]: + while size: + if not self.poll(0.5): + continue with conn.recv_lock: packet = conn.packet_buffer[condition][0] chunk = packet.data[:size] @@ -488,21 +487,25 @@ class TCP(object): def blink_new_conn_event(self): self.new_conn_event.set() self.new_conn_event.clear() - def sort_answers(self, packet, address): + def multiplex(self, packet, address): if address not in self.connections and not isinstance(packet, Syn): return if not isinstance(packet, Syn): conn = self.connections[address] - if conn.recv_seq == packet.seq: + if isinstance(packet, SynAck): + conn.recv_seq = packet.seq + elif conn.recv_seq == packet.seq: # Repeat ACK ack = Ack(packet.id) self.__send_packet(address, ack) return - elif conn.recv_seq > packet.seq: + elif packet.seq < conn.recv_seq: # Possibly DUP + return + elif packet.seq > (conn.recv_seq + 1): # Intermediate packet lost return else: conn.recv_seq = packet.seq if isinstance(packet, Packet): - if packet.checksum == TCP.checksum(packet.data): + if packet.checksum == UTCP.checksum(packet.data): if self.encrypted: packet.data = conn.my_key.decrypt_raw(packet.data) else: @@ -526,12 +529,12 @@ class TCP(object): try: packet, address = self.own_socket.recvfrom(SENT_SIZE) packet = restricted_pickle_loads(packet) - self.sort_answers(packet, address) + self.multiplex(packet, address) except pickle.UnpicklingError: continue except socket.timeout: continue - except socket.error as error: + except socket.error: self.own_socket.close() self.status = 0