# based on https://github.com/ethay012/TCP-over-UDP
"""TCP-like reliable connections implemented on top of one UDP socket.

Every datagram carries a single pickled packet object (``Syn``, ``SynAck``,
``Ack``, ``Packet``, ``Fin`` or ``FinAck``).  A background thread
(``TCP.central_receive_handler``) decodes incoming datagrams and files them
into per-connection buffers, and ``TCP.find_correct_packet`` is how callers
wait for a particular reply.  Delivery is stop-and-wait: a packet that needs
a reply is sent at most three times before the connection is dropped.

SECURITY NOTE: incoming datagrams are decoded with ``pickle``.
``RestrictedUnpickler`` only lets the packet classes above be constructed,
but every attribute value inside them is still chosen by the sender.  Treat
this transport as unsuitable for untrusted networks.
"""
import bisect
import hashlib
import io
import logging
import pickle
import random
import socket
import threading
import uuid
from struct import Struct

import simplecrypto

logger = logging.getLogger(__name__)

# Maximum payload bytes carried by one DATA packet (see TCP.data_divider).
DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512  # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH
# recvfrom() buffer size.  Per the original author, encrypted data is always
# 272 bytes bigger than the plaintext.
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272
LAST_CONNECTION = -1
FIRST = 0
PACKET_END = b'___+++^^^END^^^+++___'


class Connection:
    """Per-peer state: sequence numbers, optional RSA keys and receive buffers."""

    SMALLEST_STARTING_SEQ = 0
    HIGHEST_STARTING_SEQ = 4294967295  # 2**32 - 1

    def __init__(self, remote, encrypted=False):
        self.peer_addr = remote
        # Sequence number of the last packet we sent; advanced by seq_inc().
        self.seq = Connection.gen_starting_seq_num()
        # Highest sequence number accepted from the peer (duplicate filter).
        self.recv_seq = 0
        self.my_key = None
        if encrypted:
            self.my_key = simplecrypto.RsaKeypair()
            self.pubkey = self.my_key.publickey.serialize()
        self.peer_pub = None
        self.recv_lock = threading.Lock()
        self.send_lock = threading.Lock()
        # Received packets waiting for a consumer, keyed by packet ``type``.
        # Each list is kept sorted by ``seq`` (bisect.insort in sort_answers).
        self.packet_buffer = {
            'ACK': [],
            'SYN-ACK': [],
            'DATA': [],
            'FIN-ACK': [],
        }

    @staticmethod
    def gen_starting_seq_num():
        """Return a random initial sequence number in the allowed range."""
        return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ)

    def seq_inc(self, inc=1):
        """Advance the outgoing sequence number by ``inc`` and return it."""
        self.seq += inc
        return self.seq

    def set_ack(self, ack):
        self.ack = ack
        return ack


class UTCPPacket:
    """Base class for wire packets; packets compare and sort by ``seq``.

    ``seq`` is stamped by TCP.__send_packet just before a packet is sent.
    """

    def __cmp__(self, other):
        # Python 2 leftover; Python 3 only uses the rich comparisons below.
        return (self.seq > other.seq) - (self.seq < other.seq)

    def __lt__(self, other):
        return self.seq < other.seq

    def __gt__(self, other):
        return self.seq > other.seq

    def __eq__(self, other):
        return self.seq == other.seq


class Ack(UTCPPacket):
    """Acknowledges the packet whose ``id`` it carries."""

    type = 'ACK'

    def __init__(self, id_):
        self.id = id_
        self.seq = 0


class Fin(UTCPPacket):
    """Request to close the connection."""

    type = 'FIN'

    def __init__(self):
        self.id = uuid.uuid4().bytes
        self.seq = 0


class FinAck(UTCPPacket):
    """Final reply to a ``Fin``; carries the Fin's ``id``."""

    type = 'FIN-ACK'

    def __init__(self, id_):
        self.id = id_
        self.seq = 0


class Syn(UTCPPacket):
    """Connection request; carries the client's public key when encrypted."""

    type = 'SYN'
    checksum = None

    def __init__(self):
        self.id = uuid.uuid4().bytes
        self.seq = 0

    def set_pub(self, pubkey):
        """Attach ``pubkey`` together with its SHA-1 checksum."""
        self.checksum = TCP.checksum(pubkey)
        self.pubkey = pubkey


class SynAck(UTCPPacket):
    """Server reply to a ``Syn`` (same ``id``); may carry the server's key."""

    type = 'SYN-ACK'
    checksum = None

    def __init__(self, id_):
        self.id = id_
        self.seq = 0

    def set_pub(self, pubkey):
        """Attach ``pubkey`` together with its SHA-1 checksum."""
        self.checksum = TCP.checksum(pubkey)
        self.pubkey = pubkey


class Packet(UTCPPacket):
    """DATA packet carrying one chunk of application payload."""

    type = 'DATA'

    def __init__(self):
        self.id = uuid.uuid4().bytes
        self.checksum = 0
        self.data = b''
        self.seq = 0

    def __repr__(self):
        return f'Packet(seq={self.seq})'

    def __str__(self):
        return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data)

    def set_data(self, data):
        """Store ``data`` with its SHA-1 checksum (verified by the receiver)."""
        self.checksum = TCP.checksum(data)
        self.data = data


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that can only construct this module's packet classes."""

    _ALLOWED = {
        cls.__name__: cls
        for cls in (Packet, Ack, Fin, FinAck, Syn, SynAck)
    }

    def find_class(self, module, name):
        # Any non-builtins module name is accepted as long as the class name
        # is a known packet type.  Presumably this is because the pickled
        # module path depends on how the sender imported this file.
        if module != 'builtins' and name in self._ALLOWED:
            return self._ALLOWED[name]
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))


def restricted_pickle_loads(s):
    """Decode one datagram with RestrictedUnpickler (input is untrusted)."""
    return RestrictedUnpickler(io.BytesIO(s)).load()


class UTCPChannel:
    """Length-prefixed message framing on top of a socket-like object.

    Wire format: 4-byte big-endian payload length, payload, TERMINATOR.
    """

    HEADER = Struct('!I')
    TERMINATOR = b'\n'
    POLL_TIMEOUT = 0.1

    def __init__(self, sock):
        self.sock = sock

    def poll(self):
        return self.sock.poll(self.POLL_TIMEOUT)

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes.

        ``sock.recv`` may return fewer bytes than requested, so keep reading
        until the request is satisfied.

        Raises EOFError if the socket yields no data (b'' or None).
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.sock.recv(remaining)
            if not chunk:
                raise EOFError('Connection closed while receiving a message')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv(self):
        """Receive one framed message and return its payload."""
        header = self._recv_exact(self.HEADER.size)
        (data_len,) = self.HEADER.unpack(header)
        terminated_data = self._recv_exact(data_len + len(self.TERMINATOR))
        return terminated_data[:-len(self.TERMINATOR)]

    def send(self, data):
        """Send ``data`` as one framed message."""
        header = self.HEADER.pack(len(data))
        self.sock.send(header + data + self.TERMINATOR)


class ConnectedSOCK(object):
    """Socket-like view of one accepted peer connection of a ``TCP`` server.

    Attributes not defined here are looked up on the underlying ``TCP``.
    """

    def __init__(self, low_sock, client_addr):
        self.client_addr = client_addr
        self.low_sock = low_sock
        self.channel = UTCPChannel(self)

    def __getattribute__(self, att):
        try:
            return object.__getattribute__(self, att)
        except AttributeError:
            return getattr(self.low_sock, att)

    def getpeername(self):
        return self.client_addr

    def send(self, data):
        if self.closed:
            raise EOFError
        return self.low_sock.send(data, self.client_addr)

    def recv(self, size):
        if self.closed:
            raise EOFError
        return self.low_sock.recv(size, self.client_addr)

    @property
    def closed(self):
        """True once the UDP socket is closed or this peer is no longer connected."""
        return self.low_sock.own_socket._closed or self.client_addr not in self.low_sock.connections

    def close(self):
        if self.client_addr in self.low_sock.connections:
            self.low_sock.close(self.client_addr)

    def shutdown(self, *a, **kw):
        self.close()

    def poll(self, timeout):
        """Return True if payload is buffered for this peer, waiting up to ``timeout``."""
        if self.closed:
            return False
        return self.low_sock._poll_address(self.client_addr, timeout)


class TCP(object):
    """Socket-like endpoint offering TCP-style connections over one UDP socket.

    Server side: ``bind()`` -> ``listen()`` -> ``accept()``.
    Client side: ``connect()``, then ``send()`` / ``recv()`` / ``close()``.
    Connections are keyed by the peer's address as reported by ``recvfrom``.
    """

    host = None
    port = None
    client = False

    def __init__(self, af_type=None, sock_type=None, encrypted=False):
        """Create an unbound endpoint.

        ``af_type`` and ``sock_type`` exist for ``socket.socket``-style call
        compatibility and are ignored: the transport is always IPv4 UDP.
        With ``encrypted`` set, payloads are encrypted with the peer's
        simplecrypto RSA public key.
        """
        self.encrypted = encrypted
        self.incoming_packet_event = threading.Event()
        self.new_conn_event = threading.Event()
        self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.settimeout()
        self.connection_lock = threading.Lock()
        self.queue_lock = threading.Lock()
        self.channel = None
        self.connections = {}
        # (Syn, Connection) pairs waiting for accept(); guarded by queue_lock.
        self.connection_queue = []
        # address -> Syn received but not yet picked up by listen_handler.
        self.syn_received = {}
        # Truthy while the background loops should run.  It is set here so
        # that accept()/recv() called before listen()/connect() do not raise
        # AttributeError.
        self.status = 0

    def _poll_address(self, address, timeout):
        """Return True if DATA is buffered for ``address``, waiting up to ``timeout``.

        The wait ends early when the receive thread files any packet.
        """
        conn = self.connections.get(address)
        if conn is None:
            return False
        with conn.recv_lock:
            if conn.packet_buffer['DATA']:
                return True
        self.incoming_packet_event.wait(timeout)
        with conn.recv_lock:
            return bool(conn.packet_buffer['DATA'])

    def poll(self, timeout):
        """Poll the first connection for buffered payload (see ``_poll_address``)."""
        addresses = list(self.connections)
        if not addresses:
            return False
        return self._poll_address(addresses[0], timeout)

    def get_free_port(self):
        """Return a UDP port number the OS reported as unused a moment ago."""
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.bind(('', 0))
            return probe.getsockname()[1]

    def setsockopt(self, *a, **kw):
        pass  # accepted for socket API compatibility; no options are supported

    def settimeout(self, timeout=5):
        self.own_socket.settimeout(timeout)

    def setblocking(self, mode):
        self.own_socket.setblocking(mode)

    def __repr__(self):
        return 'TCP()'

    def __str__(self):
        return 'Connections: %s' % str(self.connections)

    def getsockname(self):
        return self.own_socket.getsockname()

    def getpeername(self):
        """Return the address of the first connection; EOFError if none."""
        addresses = list(self.connections)
        if not addresses:
            raise EOFError('Not connected')
        return addresses[0]

    def bind(self, addr):
        self.host = addr[0]
        self.port = addr[1]
        self.own_socket.bind(addr)

    def _resolve_connection(self, connection):
        """Return the connection key to use.

        ``None`` selects the first connection.

        Raises EOFError if there is no connection, or if ``connection`` is
        not one of ours.
        """
        if connection in self.connections:
            return connection
        if connection is None:
            addresses = list(self.connections)
            if not addresses:
                raise EOFError('Not connected')
            return addresses[0]
        raise EOFError('Connection not in connected devices')

    def send(self, data, connection=None):
        """Reliably send ``data`` to ``connection`` (default: first connection).

        Data is split into DATA_DIVIDE_LENGTH-byte packets.  Each packet must
        be acknowledged before the next one is sent.  Returns ``len(data)``.

        Raises EOFError when not connected or when the peer stops
        acknowledging.
        """
        if self.closed:
            raise EOFError
        try:
            connection = self._resolve_connection(connection)
            conn = self.connections[connection]
            for part in TCP.data_divider(data):
                packet = Packet()
                packet.set_data(conn.peer_pub.encrypt_raw(part) if self.encrypted else part)
                self.__send_packet(connection, packet, retransmit=True)
            return len(data)
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)

    def recv(self, size, connection=None):
        """Return up to ``size`` payload bytes from ``connection`` (default: first).

        Blocks until some payload is available.  May return fewer than
        ``size`` bytes.

        Raises EOFError when not connected, when the endpoint is stopped, or
        when the connection is torn down while waiting.
        """
        if self.closed:
            raise EOFError
        try:
            connection = self._resolve_connection(connection)
            data = self.find_correct_packet('DATA', connection, size)
            # None means the wait ended without data (stopped or disconnected).
            if not self.status or data is None:
                raise EOFError('Disconnecting')
            return data
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)

    def __send_packet(self, peer_addr, packet, retransmit=False, wait_cond='ACK'):
        """Stamp ``packet`` with the next sequence number and send it to ``peer_addr``.

        Without ``retransmit`` the packet is sent once and None is returned.
        With it, the packet is sent up to three times.  After each send the
        method waits for a ``wait_cond`` reply whose ``id`` equals the
        packet's ``id``, and returns that reply.

        Raises KeyError if ``peer_addr`` has no connection.  Raises
        EOFError('Connection lost') after the connection has been dropped
        because all three attempts failed.
        """
        conn = self.connections[peer_addr]
        packet.seq = conn.seq_inc()
        # Pickled once so every retransmission carries the same seq and id;
        # that is how the receiver recognises duplicates.
        packet_to_send = pickle.dumps(packet)
        if not retransmit:
            self.own_socket.sendto(packet_to_send, peer_addr)
            return None
        answer = None
        for _attempt in range(3):
            try:
                self.own_socket.sendto(packet_to_send, peer_addr)
                answer = self.find_correct_packet(wait_cond, peer_addr, want_id=packet.id)
            except socket.timeout:
                continue  # a timeout is a failed attempt, so the loop stays bounded
            if answer:
                break
        if not answer:
            self.drop_connection(peer_addr)
            raise EOFError('Connection lost')
        return answer

    def listen_handler(self, max_connections):
        """Listener loop: turn incoming SYNs into queued (Syn, Connection) pairs.

        Runs until ``status`` is cleared.  Ignores SYNs from an address that
        is already connected and SYNs that are already queued.  When the
        queue already holds ``max_connections`` entries, the peer is sent
        b'Connections full' instead.
        """
        try:
            while self.status:
                found = self.find_correct_packet('SYN')
                if found is None:
                    break  # status cleared while waiting
                syn, address = found
                with self.queue_lock:
                    if address in self.connections:
                        # Connections are keyed by address, so accepting this
                        # would overwrite the live connection.  It is most
                        # likely a retransmitted SYN for an accepted peer.
                        continue
                    if any(queued.id == syn.id for queued, _conn in self.connection_queue):
                        continue  # retransmitted SYN still waiting for accept()
                    if len(self.connection_queue) >= max_connections:
                        self.own_socket.sendto(b'Connections full', address)
                        continue
                    conn = Connection(address, self.encrypted)
                    if self.encrypted:
                        try:
                            conn.peer_pub = simplecrypto.RsaPublicKey(syn.pubkey)
                        except Exception:
                            # One bad SYN must not stop the listener for everyone.
                            logger.warning('Init peer public key error from %s', address, exc_info=True)
                            continue
                    self.connection_queue.append((syn, conn))
                    self.blink_new_conn_event()
        except socket.error as error:
            raise EOFError('Something went wrong in listen_handler func! Error is: %s.' % error)

    def listen(self, max_connections=1):
        """Start the receive thread and a listener queuing up to ``max_connections``."""
        self.status = 1
        self.central_receive()
        try:
            t = threading.Thread(target=self.listen_handler, args=(max_connections,))
            t.daemon = True
            t.start()
        except Exception as error:
            raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error))

    def stop(self):
        """Close the UDP socket and stop the background loops."""
        self.own_socket.close()
        self.status = 0

    def accept(self):
        """Complete the handshake for one queued connection.

        Returns ``(ConnectedSOCK, peer_address)``, or None if ``status`` is
        cleared first.  A peer that never ACKs the SYN-ACK is skipped.

        Raises EOFError on other failures.
        """
        while self.status:
            conn = None  # the handlers below must never see a stale or unbound value
            try:
                self.new_conn_event.wait(0.1)
                with self.queue_lock:
                    if not self.connection_queue:
                        continue
                    # pop() takes the most recently queued request.
                    syn, conn = self.connection_queue.pop()
                self.connections[conn.peer_addr] = conn
                syn_ack = SynAck(syn.id)
                if self.encrypted:
                    syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey))
                self.__send_packet(conn.peer_addr, syn_ack, retransmit=True)
                return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr
            except EOFError:
                if conn is not None and conn.peer_addr in self.connections:
                    self.close(conn.peer_addr)
                continue
            except Exception as error:
                if conn is not None and conn.peer_addr in self.connections:
                    self.close(conn.peer_addr)
                raise EOFError('Something went wrong in accept func: ' + str(error))

    def connect(self, server_address=('127.0.0.1', 10000)):
        """Open a connection to ``server_address`` (host, port).

        Binds a fresh local port, starts the receive thread and performs the
        SYN / SYN-ACK / ACK handshake.  When encrypted, public keys are
        exchanged during the handshake.  On failure the endpoint is closed
        and EOFError is raised.
        """
        try:
            host, port = server_address
            # recvfrom() reports peers by numeric IP, and incoming packets are
            # matched against self.connections by that address.  Resolve
            # hostnames such as 'localhost' so the server's replies match.
            server_address = (socket.gethostbyname(host), port)
            self.bind(('', self.get_free_port()))
            self.status = 1
            self.client = True
            self.central_receive()
            conn = Connection(server_address, self.encrypted)
            self.connections[server_address] = conn
            syn = Syn()
            if self.encrypted:
                syn.set_pub(conn.pubkey)
            try:
                answer = self.__send_packet(server_address, syn, retransmit=True, wait_cond='SYN-ACK')
            except EOFError:
                raise EOFError('Remote peer unreachable')
            if isinstance(answer, str):  # defensive: 'Connections full'
                raise socket.error('Server cant receive any connections right now.')
            if self.encrypted:
                try:
                    peer_pub = conn.my_key.decrypt_raw(answer.pubkey)
                    conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub)
                except Exception:
                    raise socket.error('Decrypt peer public key error')
            self.__send_packet(server_address, Ack(answer.id))
            self.channel = UTCPChannel(self)
        except EOFError:
            self._abort_connect()
            raise
        except socket.error as error:
            self._abort_connect()
            raise EOFError('The socket was closed. Error:' + str(error))

    def _abort_connect(self):
        """Undo a failed connect(): close the socket and stop the background loops."""
        self.own_socket.close()
        self.connections = {}
        self.status = 0

    def fileno(self):
        return self.own_socket.fileno()

    @property
    def closed(self):
        """True when there are no connections."""
        return not self.connections

    def drop_connection(self, connection):
        """Forget ``connection`` locally without sending anything.

        Keys that are already gone are ignored.
        """
        with self.connection_lock:
            self.connections.pop(connection, None)

    def close(self, connection=None):
        """Gracefully close ``connection`` (default: first connection).

        Sends FIN, then waits for its ACK and for the peer's FIN-ACK.  The
        local entry is dropped whether or not the handshake completes.  A
        client endpoint with no connections left is stopped.

        Raises EOFError (wrapping the cause) on any failure.
        """
        try:
            connection = self._resolve_connection(connection)
            fin = Fin()
            try:
                self.__send_packet(connection, fin, retransmit=True)
                fin_ack = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id)
                if not fin_ack:
                    raise Exception("The receiver didn't send the fin packet")
            finally:
                # The peer removes its side on receiving FIN (see disconnect),
                # so a half-finished handshake leaves nothing usable locally.
                self.drop_connection(connection)
                if not self.connections and self.client:
                    self.stop()
        except Exception as error:
            raise EOFError('Something went wrong in the close func! Error is: %s.' % error)

    def disconnect(self, connection, fin_id):
        """Answer a peer's FIN (id ``fin_id``) with ACK + FIN-ACK and forget the peer.

        Called from the receive thread.  The FIN-ACK is sent on a best-effort
        basis.  The connection entry is removed even if replying fails.
        """
        try:
            self.__send_packet(connection, Ack(fin_id))
            try:
                self.__send_packet(connection, FinAck(fin_id))
            except (socket.error, KeyError):
                pass  # best effort
        except Exception as error:
            raise EOFError('Something went wrong in disconnect func:%s ' % error)
        finally:
            with self.connection_lock:
                self.connections.pop(connection, None)

    @staticmethod
    def data_divider(data):
        """Split ``data`` into consecutive slices of at most DATA_DIVIDE_LENGTH bytes."""
        return [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)]

    @staticmethod
    def checksum(source_bytes):
        """Return the SHA-1 digest of ``source_bytes``."""
        return hashlib.sha1(source_bytes).digest()

    def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH, want_id=None):
        """Wait for and consume a received packet.

        * ``address[0] == 'Any'``: pop any pending SYN and return
          ``(syn_packet, peer_address)``.  ``condition`` is not consulted.
        * ``condition == 'DATA'``: return up to ``size`` payload bytes
          buffered for ``address``, spanning packets if needed.  Waits until
          some payload exists.
        * ``'ACK'`` / ``'SYN-ACK'`` / ``'FIN-ACK'``: pop the highest-``seq``
          buffered packet of that type.  If ``want_id`` is given and does not
          match, that packet is discarded and the wait continues.  Gives up
          after two attempts, waiting up to 0.5 s between them.

        Returns None if ``status`` is cleared, if the attempts run out, or if
        the connection for ``address`` no longer exists.
        """
        tries = 0
        while tries < 2 and self.status:
            try:
                if address[0] == 'Any':
                    # syn_received maps address -> packet; callers want (packet, address).
                    peer_addr, syn = self.syn_received.popitem()
                    return syn, peer_addr
                conn = self.connections.get(address)
                if conn is None:
                    # Torn down (e.g. by disconnect): nothing more will arrive,
                    # so don't wait forever.
                    return None
                if condition in ('ACK', 'SYN-ACK', 'FIN-ACK'):
                    tries += 1
                if condition == 'DATA':
                    return self._read_data(conn, size)
                with conn.recv_lock:
                    buffer = conn.packet_buffer[condition]
                    if not buffer:
                        raise KeyError(condition)
                    packet = buffer.pop()
                if want_id and packet.id != want_id:
                    raise KeyError(condition)
                return packet
            except KeyError:
                # Nothing suitable yet; wait for the receive thread to file something.
                self.incoming_packet_event.wait(0.5)
        return None

    def _read_data(self, conn, size):
        """Consume up to ``size`` bytes from ``conn``'s DATA buffer.

        A partially read packet keeps its unread remainder at the front of
        the buffer.

        Raises KeyError when the buffer is empty.
        """
        with conn.recv_lock:
            buffer = conn.packet_buffer['DATA']
            if not buffer:
                raise KeyError('DATA')
            parts = []
            while size > 0 and buffer:
                packet = buffer[0]
                chunk = packet.data[:size]
                parts.append(chunk)
                packet.data = packet.data[len(chunk):]
                size -= len(chunk)
                if not packet.data:
                    buffer.pop(0)
            return b''.join(parts)

    def blink_incoming_packet_event(self):
        """Wake every thread currently waiting on ``incoming_packet_event``."""
        self.incoming_packet_event.set()
        self.incoming_packet_event.clear()

    def blink_new_conn_event(self):
        """Wake every thread currently waiting on ``new_conn_event``."""
        self.new_conn_event.set()
        self.new_conn_event.clear()

    def sort_answers(self, packet, address):
        """File a decoded packet from ``address``; called by the receive thread.

        SYNs are stored in ``syn_received``.  Other packets are ignored unless
        ``address`` is connected.  They are then handled as follows:

        * De-duplicated by ``seq``: a repeat of the last accepted seq is
          re-ACKed, and older ones are dropped.
        * DATA checksums are verified, and payloads are decrypted when
          encrypted.
        * FIN triggers ``disconnect``.
        * Everything else is stored in the connection's buffer, and new DATA
          packets are ACKed.
        """
        if isinstance(packet, Syn):
            pending = list(self.syn_received.values())  # snapshot: listener pops concurrently
            if all(waiting.id != packet.id for waiting in pending):
                self.syn_received[address] = packet
            self.blink_incoming_packet_event()
            return
        conn = self.connections.get(address)
        if conn is None:
            return
        if conn.recv_seq == packet.seq:
            # The peer is retransmitting the last accepted packet, presumably
            # because our reply was lost.
            self.__send_packet(address, Ack(packet.id))
            return
        if conn.recv_seq > packet.seq:
            return  # stale
        if isinstance(packet, Packet) and packet.checksum != TCP.checksum(packet.data):
            # Corrupt: drop it WITHOUT advancing recv_seq.  Otherwise the
            # sender's identical-seq retransmission would be taken for a
            # duplicate, ACKed, and its payload silently lost.
            return
        conn.recv_seq = packet.seq
        if isinstance(packet, Packet) and self.encrypted:
            packet.data = conn.my_key.decrypt_raw(packet.data)
        if isinstance(packet, Fin):
            self.disconnect(address, packet.id)
        else:
            buffer = conn.packet_buffer.get(packet.type)
            if buffer is None:
                return  # not a packet type anyone waits for
            with conn.recv_lock:
                if all(queued.id != packet.id for queued in buffer):
                    bisect.insort(buffer, packet)
            if isinstance(packet, Packet):
                self.__send_packet(address, Ack(packet.id))
        self.blink_incoming_packet_event()

    def central_receive_handler(self):
        """Receive loop: decode each datagram and hand it to ``sort_answers``.

        Runs until ``status`` is cleared.  A socket error while receiving or
        replying closes the socket and clears ``status``.  Undecodable or
        malformed datagrams are logged and skipped, so a single bad datagram
        cannot stop the loop.
        """
        while self.status:
            try:
                raw, address = self.own_socket.recvfrom(SENT_SIZE)
            except socket.timeout:
                continue
            except socket.error:
                self.own_socket.close()
                self.status = 0
                return
            try:
                # Untrusted input (see RestrictedUnpickler).  pickle can raise
                # many exception types on garbage, e.g. EOFError for b''.
                packet = restricted_pickle_loads(raw)
            except Exception:
                logger.debug('Dropping undecodable datagram from %s', address, exc_info=True)
                continue
            if not isinstance(packet, UTCPPacket):
                logger.debug('Dropping non-packet object from %s', address)
                continue
            try:
                self.sort_answers(packet, address)
            except socket.timeout:
                continue
            except socket.error:
                self.own_socket.close()
                self.status = 0
                return
            except Exception:
                # For example, a packet missing attributes, or its connection
                # vanishing concurrently.  Keep serving the other connections.
                logger.warning('Failed to process packet from %s', address, exc_info=True)

    def central_receive(self):
        """Start ``central_receive_handler`` on a daemon thread."""
        t = threading.Thread(target=self.central_receive_handler)
        t.daemon = True
        t.start()