# based on https://github.com/ethay012/TCP-over-UDP
"""Socket-like, optionally RSA-encrypted, TCP-style transport running over UDP.

``TCP`` mimics the parts of the ``socket`` API used by its callers (bind, listen,
accept, connect, send, recv, close, poll, fileno, ...).  Every segment is a
pickled ``TCPPacket`` carried in one UDP datagram; incoming datagrams are
decoded with ``RestrictedUnpickler``, which only admits ``TCPPacket``.
"""
import random
import socket
import pickle
import threading
import io
import hashlib

try:
    import simplecrypto
except ImportError:  # only needed for TCP(encrypted=True)
    simplecrypto = None

DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512  # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272  # Encrypted data always 272 bytes bigger
LAST_CONNECTION = -1
FIRST = 0
PACKET_END = b'___+++^^^END^^^+++___'
# Raw (not pickled) datagram a listener sends when its accept queue is full.
CONNECTIONS_FULL = b'Connections full'

# need for emulate
AF_INET = None
SOCK_STREAM = None


class KeyPair:
    """Our RSA key pair for one peer, plus that peer's public key once known."""
    my_key = None
    peer_pub = None

    def __init__(self, sec):
        self.my_key = sec


class TCPPacket(object):
    """TCP-style segment (header fields plus payload) sent pickled in one UDP datagram.

    Only ``seq``, ``ack``, the ACK/SYN/FIN flags, ``checksum`` and ``data`` are
    used by this module; the other fields mirror the TCP header layout and keep
    their defaults.
    """
    SMALLEST_STARTING_SEQ = 0
    HIGHEST_STARTING_SEQ = 4294967295

    def __init__(self):
        self.seq = TCPPacket.gen_starting_seq_num()  # 32bit
        self.ack = 0  # 32bit
        self.data_offset = 0  # 4 bits
        self.reserved_field = 0  # 3 bits reserved for future use, must be zero
        # Flags, 1 bit each.
        self.flag_ns = 0
        self.flag_cwr = 0
        self.flag_ece = 0
        self.flag_urg = 0
        self.flag_ack = 0
        self.flag_psh = 0
        self.flag_rst = 0
        self.flag_syn = 0
        self.flag_fin = 0
        self.window_size = 0  # 16bit
        # 0, or the SHA-1 digest of ``data`` (TCP.checksum) set by the sender.
        self.checksum = 0
        self.urgent_pointer = 0  # 16bit
        self.options = 0  # 0-320 bits, divisible by 32
        self.padding = 0  # pads the header to a 32-bit boundary
        self.data = b''

    def __repr__(self):
        return 'TCPpacket()'

    def __str__(self):
        return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \
            % (self.seq, self.ack, self.flag_ack, self.flag_syn, self.flag_fin, self.packet_type(), self.data)

    def __cmp__(self, other):
        # Python 2 only; kept for compatibility.
        return (self.seq > other.seq) - (self.seq < other.seq)

    def __lt__(self, other):
        # Python 3 ignores __cmp__; this lets sorted() order packets by seq.
        return self.seq < other.seq

    def packet_type(self):
        """Classify the packet from its flags; '' when it has no flags and no data."""
        if self.flag_syn == 1 and self.flag_ack == 1:
            return 'SYN-ACK'
        elif self.flag_ack == 1 and self.flag_fin == 1:
            return 'FIN-ACK'
        elif self.flag_syn == 1:
            return 'SYN'
        elif self.flag_ack == 1:
            return 'ACK'
        elif self.flag_fin == 1:
            return 'FIN'
        elif self.data != b'':
            return 'DATA'
        return ''

    def set_flags(self, ack=False, syn=False, fin=False):
        """Set the ACK/SYN/FIN flags to exactly the given values (others untouched)."""
        self.flag_ack = 1 if ack else 0
        self.flag_syn = 1 if syn else 0
        self.flag_fin = 1 if fin else 0

    @staticmethod
    def gen_starting_seq_num():
        """Return a random initial sequence number in the 32-bit range."""
        return random.randint(TCPPacket.SMALLEST_STARTING_SEQ, TCPPacket.HIGHEST_STARTING_SEQ)


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves the global ``TCPPacket``."""

    def find_class(self, module, name):
        # Only allow TCPPacket
        if name == 'TCPPacket':
            return TCPPacket
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))


def restricted_pickle_loads(s):
    """Helper function analogous to pickle.loads()."""
    return RestrictedUnpickler(io.BytesIO(s)).load()


class ConnectedSOCK(object):
    """Per-peer view of a listening ``TCP`` object, as returned by ``TCP.accept``.

    Attributes not defined here are looked up on the underlying ``TCP`` object.
    """

    def __init__(self, low_sock, client_addr):
        self.client_addr = client_addr
        self.low_sock = low_sock

    def __getattribute__(self, att):
        try:
            return object.__getattribute__(self, att)
        except AttributeError:
            return getattr(self.low_sock, att)

    def getpeername(self):
        return self.client_addr

    def send(self, data):
        if self.closed:
            raise EOFError
        return self.low_sock.send(data, self.client_addr)

    def recv(self, size):
        if self.closed:
            raise EOFError
        return self.low_sock.recv(size, self.client_addr)

    @property
    def closed(self):
        """True when the UDP socket is closed or this peer is gone / has sent FIN."""
        return self.low_sock.own_socket._closed or (
            self.client_addr not in self.low_sock.connections
            or self.low_sock.connections[self.client_addr].flag_fin)

    def close(self):
        if self.client_addr in self.low_sock.connections:
            self.low_sock.close(self.client_addr)

    def shutdown(self, *a, **kw):
        self.close()

    def poll(self, timeout):
        """Return True if data from this peer is buffered, waiting up to ``timeout`` seconds once."""
        if self.client_addr in self.packets_received['DATA or FIN']:
            return True
        self.incoming_packet_event.wait(timeout)
        return self.client_addr in self.packets_received['DATA or FIN']


class TCP(object):
    """Socket-like endpoint speaking a simplified TCP handshake/ACK protocol over one UDP socket.

    A background thread (``central_receive``) reads every datagram, decodes it
    and files it into ``packets_received`` by packet type; the public methods
    poll those buffers.  Servers call bind/listen/accept, clients call connect;
    both then use send/recv/close.  With ``encrypted=True`` the peers exchange
    RSA public keys during the handshake (requires ``simplecrypto``).
    """
    host = None
    port = None
    client = False
    # NOTE(review): these containers are class attributes, so every TCP instance
    # in the process shares them (the per-instance versions in __init__ were
    # commented out upstream). Confirm the sharing is intended before running a
    # client and a server in one process.
    peer_keypair = {}
    connections = {}
    connection_queue = []
    packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
    # 0 until listen()/connect() activates the endpoint (read by recv and the loops).
    status = 0
    # Attributes a decoded packet must carry (read by sort_answers and friends).
    _REQUIRED_PACKET_FIELDS = ('seq', 'ack', 'flag_ack', 'flag_syn', 'flag_fin', 'checksum', 'data')

    def __init__(self, af_type=None, sock_type=None, encrypted=False):
        """Create the UDP socket; ``af_type``/``sock_type`` are ignored (socket() compatibility).

        Raises:
            ImportError: ``encrypted=True`` but ``simplecrypto`` is not installed.
        """
        if encrypted and simplecrypto is None:
            raise ImportError('encrypted=True requires the simplecrypto package')
        self.encrypted = encrypted
        self.incoming_packet_event = threading.Event()
        self.new_conn_event = threading.Event()
        # Per connection, seq holds the last packet sent and ack the next one expected.
        self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)  # UDP socket used for communication.
        self.settimeout()
        self.connection_lock = threading.Lock()
        self.queue_lock = threading.Lock()
        # seq of the last DATA packet accepted per peer, to drop retransmitted duplicates.
        self._last_data_seq = {}

    def poll(self, timeout):
        """Return True if data from the first peer is buffered, waiting up to ``timeout`` seconds once."""
        if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
            return True
        self.incoming_packet_event.wait(timeout)
        return bool(len(self.connections)) and \
            list(self.connections.keys())[0] in self.packets_received['DATA or FIN']

    def get_free_port(self):
        """Return a UDP port number the OS reported as free (it may be taken again later)."""
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.bind(('', 0))
            return probe.getsockname()[1]

    def setsockopt(self, *a, **kw):
        pass

    def settimeout(self, timeout=5):
        self.own_socket.settimeout(timeout)

    def setblocking(self, mode):
        self.own_socket.setblocking(mode)

    def __repr__(self):
        return 'TCP()'

    def __str__(self):
        return 'Connections: %s' % str(self.connections)

    def getsockname(self):
        return self.own_socket.getsockname()

    def getpeername(self):
        if len(self.connections):
            return list(self.connections.keys())[0]
        raise EOFError('Not connected')

    def bind(self, addr):
        self.host = addr[0]
        self.port = addr[1]
        self.own_socket.bind(addr)

    def _resolve_connection(self, connection):
        """Return the peer address to use; ``None`` selects the first connected peer.

        Raises:
            EOFError: ``connection`` is unknown, or there is no peer at all.
        """
        if connection in self.connections:
            return connection
        if connection is not None:
            raise EOFError('Connection not in connected devices')
        if not self.connections:
            raise EOFError('Not connected')
        return next(iter(self.connections))

    def send(self, data, connection=None):
        """Reliably send ``data`` to ``connection`` (default: first peer); return ``len(data)``.

        The payload is split by ``data_divider`` and every chunk (plus the
        trailing ``PACKET_END`` marker) is sent up to 3 times until the peer
        returns the matching ACK.

        Raises:
            EOFError: not connected, the peer stopped acknowledging (the
                connection is then dropped), or the socket failed.
        """
        connection = self._resolve_connection(connection)
        try:
            for data_part in TCP.data_divider(data):
                conn = self.connections[connection]
                if self.encrypted:
                    data_chunk = self.peer_keypair[connection].peer_pub.encrypt_raw(data_part)
                else:
                    data_chunk = data_part
                conn.checksum = TCP.checksum(data_chunk)
                conn.data = data_chunk
                conn.set_flags()
                packet_to_send = pickle.dumps(conn)
                conn.checksum = 0
                conn.data = b''
                # The receiver acknowledges with seq + len(payload as sent); capture
                # it now because send_ack() may bump conn.seq concurrently.
                expected_ack = conn.seq + len(data_chunk)
                if not self._transmit_until_acked(packet_to_send, connection, expected_ack):
                    self.drop_connection(connection)
                    raise EOFError('Connection lost')
                conn.seq += len(data_part)
            return len(data)
        except KeyError:
            # The peer was dropped (e.g. it disconnected) while we were sending.
            raise EOFError('Connection lost')
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)

    def _transmit_until_acked(self, packet_to_send, connection, expected_ack, max_attempts=3):
        """Send ``packet_to_send`` up to ``max_attempts`` times; True once the matching ACK arrives."""
        for _ in range(max_attempts):
            try:
                self.own_socket.sendto(packet_to_send, connection)
            except socket.timeout:
                continue  # counts as a failed attempt instead of looping forever
            answer = self.find_correct_packet('ACK', connection)
            # Ignore ACKs for earlier packets (e.g. a late reply to a retransmission).
            if answer is not None and answer.ack == expected_ack:
                return True
        return False

    def recv(self, size, connection=None):
        """Return up to ``size`` buffered bytes from ``connection`` (default: first peer), blocking.

        Raises:
            EOFError: not connected, the endpoint stopped, the peer disconnected
                with nothing left buffered, or the socket failed.
        """
        connection = self._resolve_connection(connection)
        try:
            data = self.find_correct_packet('DATA or FIN', connection, size)
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
        if not self.status or data is None:
            raise EOFError('Disconnecting')
        return data

    def send_ack(self, connection, ack):
        """Send an ACK packet carrying ``ack`` to ``connection`` (bumps our seq by 1)."""
        self.connections[connection].ack = ack
        self.connections[connection].seq += 1
        self.connections[connection].set_flags(ack=True)
        self.connections[connection].data = b''
        packet_to_send = pickle.dumps(self.connections[connection])
        self.own_socket.sendto(packet_to_send, connection)  # after receiving correct info sends ack
        self.connections[connection].set_flags()

    def listen_handler(self, max_connections):
        """Listener thread: queue incoming SYNs for accept() until the endpoint stops.

        A SYN beyond ``max_connections`` pending entries gets a ``CONNECTIONS_FULL``
        reply; in encrypted mode a SYN with an unusable public key is dropped.
        """
        try:
            while self.status:
                found = self.find_correct_packet('SYN')
                if found is None:
                    break  # the endpoint stopped
                answer, address = found
                with self.queue_lock:
                    if len(self.connection_queue) >= max_connections:
                        self.own_socket.sendto(CONNECTIONS_FULL, address)
                        continue
                    if self.encrypted and not self._init_peer_key(answer, address):
                        continue  # one bad client must not stop the listener
                    self.connection_queue.append((answer, address))
                    self.blink_new_conn_event()
        except socket.error as error:
            raise EOFError('Something went wrong in listen_handler func! Error is: %s.' % error)

    def _init_peer_key(self, answer, address):
        """Create our key pair for ``address`` and store the public key from its SYN; return success."""
        if answer.checksum != TCP.checksum(answer.data):
            return False
        try:
            # FIXME: for some reason slowly creates a key (~5 sec)
            keys = KeyPair(simplecrypto.RsaKeypair())
            keys.peer_pub = simplecrypto.RsaPublicKey(answer.data)
        except Exception:
            self.peer_keypair.pop(address, None)
            return False
        self.peer_keypair[address] = keys
        return True

    def listen(self, max_connections=1):
        """Activate the endpoint and start the receive and listener threads."""
        self.status = 1
        self.central_receive()
        try:
            t = threading.Thread(target=self.listen_handler, args=(max_connections,))
            t.daemon = True
            t.start()
        except Exception as error:
            raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error))

    def _pop_pending_connection(self):
        """Block until a SYN is queued; return the oldest ``(syn_packet, address)``."""
        while True:
            self.new_conn_event.wait(0.1)
            with self.queue_lock:
                if self.connection_queue:
                    # FIFO, like a TCP backlog, so early clients are not starved.
                    return self.connection_queue.pop(0)

    def accept(self):
        """Wait for a queued SYN, complete the handshake and return ``(ConnectedSOCK, address)``.

        Raises:
            EOFError: the handshake failed; the half-open connection is closed or dropped.
        """
        address = None
        try:
            answer, address = self._pop_pending_connection()
            conn = TCPPacket()
            self.connections[address] = conn
            conn.ack = answer.seq + 1
            conn.seq += 1
            conn.set_flags(ack=True, syn=True)
            if self.encrypted:
                keys = self.peer_keypair[address]
                conn.data = keys.peer_pub.encrypt_raw(keys.my_key.publickey.serialize())
                conn.checksum = TCP.checksum(conn.data)
            packet_to_send = pickle.dumps(conn)
            conn.data = b''
            conn.checksum = 0
            # Retransmit the SYN-ACK until the client's ACK arrives.
            answer = None
            while answer is None:
                if not self.status:
                    raise EOFError('Listener stopped during the handshake')
                try:
                    self.own_socket.sendto(packet_to_send, address)
                except socket.timeout:
                    continue
                answer = self.find_correct_packet('ACK', address)
            conn.set_flags()
            conn.ack = answer.seq + 1
            return ConnectedSOCK(self, address), address
        except Exception as error:
            if address is not None and address in self.connections:
                try:
                    self.close(address)
                except EOFError:
                    self.drop_connection(address)
            raise EOFError('Something went wrong in accept func: ' + str(error))

    def connect(self, server_address=('127.0.0.1', 10000)):
        """Bind to a free port and perform the SYN / SYN-ACK / ACK handshake with ``server_address``.

        Raises:
            EOFError: the handshake failed; the socket is closed and state reset.
        """
        try:
            self.bind(('', self.get_free_port()))
            self.status = 1
            self.client = True
            self.central_receive()
            conn = TCPPacket()
            self.connections[server_address] = conn
            conn.set_flags(syn=True)
            if self.encrypted:
                keys = KeyPair(simplecrypto.RsaKeypair())  # FIXME: but here it creates the key quickly
                self.peer_keypair[server_address] = keys
                pubkey = keys.my_key.publickey.serialize()
                conn.checksum = TCP.checksum(pubkey)
                conn.data = pubkey
            first_packet_to_send = pickle.dumps(conn)
            conn.data = b''
            conn.checksum = 0
            self.own_socket.sendto(first_packet_to_send, server_address)
            conn.set_flags()
            answer = self.find_correct_packet('SYN-ACK', server_address)
            if answer is None:
                raise socket.error('No answer from the server.')
            if isinstance(answer, str):  # the listener replied CONNECTIONS_FULL
                raise socket.error('Server cant receive any connections right now.')
            if self.encrypted:
                try:
                    peer_pub = keys.my_key.decrypt_raw(answer.data)
                    keys.peer_pub = simplecrypto.RsaPublicKey(peer_pub)
                except Exception:
                    raise socket.error('Decrypt peer public key error')
                if not peer_pub or answer.checksum != TCP.checksum(answer.data):
                    raise socket.error('Invalid peer public key')
            conn.ack = answer.seq + 1
            conn.seq += 1
            conn.set_flags(ack=True)
            second_packet_to_send = pickle.dumps(conn)
            self.own_socket.sendto(second_packet_to_send, server_address)
            conn.set_flags()
        except socket.error as error:
            self.own_socket.close()
            self.connections = {}
            self.peer_keypair = {}
            self._last_data_seq = {}
            self.status = 0
            raise EOFError('The socket was closed. Error:' + str(error))

    def fileno(self):
        return self.own_socket.fileno()

    @property
    def closed(self):
        return self.own_socket._closed

    def drop_connection(self, connection):
        """Forget ``connection`` and its key pair; a no-op for unknown peers."""
        with self.connection_lock:
            self.connections.pop(connection, None)
            self.peer_keypair.pop(connection, None)
            self._last_data_seq.pop(connection, None)

    def close(self, connection=None):
        """Run the FIN handshake with ``connection`` (default: first peer) and drop it.

        A client closes its UDP socket once no connections remain.

        Raises:
            EOFError: wrapping any failure of the handshake.
        """
        try:
            connection = self._resolve_connection(connection)
            conn = self.connections[connection]
            conn.set_flags(fin=True)
            conn.seq += 1
            self.own_socket.sendto(pickle.dumps(conn), connection)
            # The peer first ACKs our FIN ...
            self.find_correct_packet('ACK', connection)
            conn.ack += 1
            # ... then sends its own FIN-ACK, which we acknowledge.
            answer = self.find_correct_packet('FIN-ACK', connection)
            if answer is None or answer.flag_fin != 1:
                raise Exception('The receiver didn\'t send the fin packet')
            self.send_ack(connection, conn.ack + 1)
            self.drop_connection(connection)
            if len(self.connections) == 0 and self.client:
                self.own_socket.close()
                self.status = 0
        except Exception as error:
            raise EOFError('Something went wrong in the close func! Error is: %s.' % error)

    def disconnect(self, connection):
        """Answer a peer's FIN: send ACK then FIN-ACK, wait briefly for its ACK, drop the peer.

        NOTE(review): this runs on the receive thread (via sort_answers), so the
        final ACK cannot be filed while we wait; the wait simply times out.
        """
        try:
            conn = self.connections[connection]
            conn.ack += 1
            conn.seq += 1
            conn.set_flags(ack=True)
            self.own_socket.sendto(pickle.dumps(conn), connection)
            conn.set_flags(fin=True, ack=True)
            conn.seq += 1
            self.own_socket.sendto(pickle.dumps(conn), connection)
            self.find_correct_packet('ACK', connection)
            self.drop_connection(connection)
        except Exception as error:
            raise EOFError('Something went wrong in disconnect func:%s ' % error)

    @staticmethod
    def data_divider(data):
        """Split ``data`` into DATA_DIVIDE_LENGTH-byte chunks followed by the PACKET_END marker."""
        chunks = [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)]
        chunks.append(PACKET_END)
        return chunks

    @staticmethod
    def checksum(source_bytes):
        """Return the SHA-1 digest of ``source_bytes``."""
        return hashlib.sha1(source_bytes).digest()

    def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
        """Wait for and remove a received packet of kind ``condition`` (a packets_received key).

        With the default ``address`` the first such packet from any peer is
        returned as ``(packet, address)``; otherwise only ``address`` is
        considered, and for ``'DATA or FIN'`` up to ``size`` buffered payload
        bytes are returned.  ``'ACK'`` lookups give up after two tries (~0.2 s);
        the others keep polling while the endpoint is active.  Returns None when
        giving up, when the endpoint stops, or when a disconnected peer's data
        buffer is drained.
        """
        tries = 0
        while tries < 2 and self.status:
            try:
                if address[0] == 'Any':
                    found_address, packet = self.packets_received[condition].popitem()
                    return packet, found_address
                if condition == 'ACK':
                    tries += 1
                if condition == 'DATA or FIN':
                    return self._take_buffered_data(address, size)
                return self.packets_received[condition].pop(address)
            except KeyError:
                if condition == 'DATA or FIN' and address[0] != 'Any' and address not in self.connections:
                    return None  # the peer is gone and nothing is left to read
                self.incoming_packet_event.wait(0.1)
        return None

    def _take_buffered_data(self, address, size):
        """Remove and return up to ``size`` bytes buffered for ``address`` (KeyError if none)."""
        with self.connection_lock:
            buffered = self.packets_received['DATA or FIN'][address]
            remainder = buffered[size:]
            if remainder:
                self.packets_received['DATA or FIN'][address] = remainder
            else:
                del self.packets_received['DATA or FIN'][address]
            return buffered[:size]

    def blink_incoming_packet_event(self):
        self.incoming_packet_event.set()
        self.incoming_packet_event.clear()

    def blink_new_conn_event(self):
        self.new_conn_event.set()
        self.new_conn_event.clear()

    def sort_answers(self, packet, address):
        """File a decoded packet: FIN triggers disconnect, DATA is buffered and ACKed, others are stored."""
        kind = packet.packet_type()
        if address not in self.connections and kind != 'SYN':
            return
        if kind == 'FIN':
            self.disconnect(address)
        elif kind == 'DATA':
            self._handle_data(packet, address)
        elif kind == '':
            pass  # redundant packet
        else:
            self.packets_received[kind][address] = packet
            self.blink_incoming_packet_event()

    def _handle_data(self, packet, address):
        """Buffer a DATA packet's payload (once per seq) and acknowledge it."""
        if packet.checksum != TCP.checksum(packet.data):
            return  # corrupted: no ACK, so the sender retransmits
        if self._last_data_seq.get(address) != packet.seq:
            try:
                if self.encrypted:
                    data_chunk = self.peer_keypair[address].my_key.decrypt_raw(packet.data)
                else:
                    data_chunk = packet.data
            except Exception:
                return  # undecryptable: treat like a corrupted packet
            if data_chunk != PACKET_END:
                with self.connection_lock:
                    buffered = self.packets_received['DATA or FIN'].get(address, b'')
                    self.packets_received['DATA or FIN'][address] = buffered + data_chunk
            self._last_data_seq[address] = packet.seq
        # (Re-)acknowledge, also for a retransmitted duplicate whose ACK was lost.
        self.send_ack(address, packet.seq + len(packet.data))
        self.blink_incoming_packet_event()

    @staticmethod
    def _decode_packet(raw):
        """Unpickle a datagram into a well-formed TCPPacket, or return None.

        NOTE: datagrams come from the network; RestrictedUnpickler only admits
        TCPPacket, and the instance state is checked so a hostile packet cannot
        shadow methods or omit/mistype the fields this module reads.
        """
        try:
            packet = restricted_pickle_loads(raw)
        except Exception:
            return None
        if not isinstance(packet, TCPPacket):
            return None
        state = vars(packet)
        if any(not isinstance(name, str) or hasattr(TCPPacket, name) for name in state):
            return None
        if any(name not in state for name in TCP._REQUIRED_PACKET_FIELDS):
            return None
        if not (isinstance(state['seq'], int) and isinstance(state['ack'], int)
                and isinstance(state['data'], bytes)):
            return None
        return packet

    def _handle_connections_full(self, address):
        """Deliver a listener's CONNECTIONS_FULL reply to a pending connect() as a str."""
        if address in self.connections:
            self.packets_received['SYN-ACK'][address] = CONNECTIONS_FULL.decode()
            self.blink_incoming_packet_event()

    def central_receive_handler(self):
        """Receive thread: read, decode and file datagrams until the endpoint stops."""
        while self.status:
            try:
                raw, address = self.own_socket.recvfrom(SENT_SIZE)
                if raw == CONNECTIONS_FULL:
                    self._handle_connections_full(address)
                    continue
                packet = TCP._decode_packet(raw)
                if packet is not None:  # malformed datagrams are dropped, not fatal
                    self.sort_answers(packet, address)
            except socket.timeout:
                continue
            except EOFError:
                continue  # disconnect() failed for one peer; keep serving the others
            except socket.error:
                self.own_socket.close()
                self.status = 0

    def central_receive(self):
        """Start the daemon receive thread."""
        t = threading.Thread(target=self.central_receive_handler)
        t.daemon = True
        t.start()