diff --git a/eudp/__init__.py b/eudp/__init__.py new file mode 100644 index 0000000..5254713 --- /dev/null +++ b/eudp/__init__.py @@ -0,0 +1,590 @@ +# based on https://github.com/ethay012/EUDP-over-UDP +import random +import socket +import pickle +import threading +import io +import hashlib +import simplecrypto +from struct import Struct +import uuid +import bisect + +DATA_DIVIDE_LENGTH = 8000 +PACKET_HEADER_SIZE = 512 # Pickle service info +DATA_LENGTH = DATA_DIVIDE_LENGTH +SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272 # Encrypted data always 272 bytes bigger + +class Connection: + SMALLEST_STARTING_SEQ = 0 + HIGHEST_STARTING_SEQ = 4294967295 + def __init__(self, remote, encrypted=False): + self.fileno = 0 + self.peer_addr = remote + self.seq = Connection.gen_starting_seq_num() + self.recv_seq = -1 + self.my_key = None + if encrypted: + self.my_key = simplecrypto.RsaKeypair() + self.pubkey = self.my_key.publickey.serialize() + self.peer_pub = None + self.recv_lock = threading.Lock() + self.send_lock = threading.Lock() + self.packet_buffer = { + 'ACK': [], + 'SYN-ACK': [], + 'DATA': [], + 'FIN-ACK': [] + } + @staticmethod + def gen_starting_seq_num(): + return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ) + def seq_inc(self, inc=1): + self.seq += inc + return self.seq + def set_ack(self, ack): + self.ack = ack + return ack + +class EUDPPacket: + def __cmp__(self, other): + return (self.seq > other.seq) - (self.seq < other.seq) + def __lt__(self, other): + return self.seq < other.seq + def __gt__(self, other): + return self.seq > other.seq + def __eq__(self, other): + return self.seq == other.seq + +class Ack(EUDPPacket): + type = 'ACK' + def __init__(self, id_): + self.id = id_ + self.seq = 0 +class Fin(EUDPPacket): + type = 'FIN' + def __init__(self): + self.id = uuid.uuid4().bytes + self.seq = 0 +class FinAck(EUDPPacket): + type = 'FIN-ACK' + def __init__(self, id_): + self.id = id_ + self.seq = 0 +class Syn(EUDPPacket): + type = 'SYN' + checksum = None + def __init__(self): + self.id = uuid.uuid4().bytes + self.seq = 0 + def set_pub(self, pubkey): + self.checksum = EUDP.checksum(pubkey) + self.pubkey = pubkey +class SynAck(EUDPPacket): + type = 'SYN-ACK' + checksum = None + def __init__(self, id_): + self.id = id_ + self.seq = 0 + def set_pub(self, pubkey): + self.checksum = EUDP.checksum(pubkey) + self.pubkey = pubkey +class Packet(EUDPPacket): + type = 'DATA' + def __init__(self): + self.id = uuid.uuid4().bytes + self.checksum = 0 + self.data = b'' + self.seq = 0 + def __repr__(self): + return f'Packet(seq={self.seq})' + def __str__(self): + return 'UUID: %d, DATA:%s' % (uuid.UUID(bytes=self.id), self.data) + def set_data(self, data): + self.checksum = EUDP.checksum(data) + self.data = data + +allowed_unpickle_class = { + 'Packet': Packet, + 'Ack': Ack, + 'Fin': Fin, + 'FinAck': FinAck, + 'Syn': Syn, + 'SynAck': SynAck + } + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module != 'builtins' and name in allowed_unpickle_class: + return allowed_unpickle_class[name] + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) +def restricted_pickle_loads(s): + return RestrictedUnpickler(io.BytesIO(s)).load() + +class EUDPChannel: + HEADER = Struct('!I') + POLL_TIMEOUT = 0.1 + def __init__(self, sock): + self.sock = sock + def poll(self): + return self.sock.poll(self.POLL_TIMEOUT) + def recv(self): + header = self.sock.recv(self.HEADER.size) + data_len = self.HEADER.unpack(header)[0] + return self.sock.recv(data_len) + def send(self, data): + header = self.HEADER.pack(len(data)) + self.sock.send(header + data) + +class ConnectedSOCK(object): + def __init__(self, low_sock, client_addr): + self.client_addr = client_addr + self.low_sock = low_sock + self.channel = EUDPChannel(self) + def __getattribute__(self, att): + try: + return object.__getattribute__(self, att) + except AttributeError: + return getattr(self.low_sock, att) + def getpeername(self): + return self.client_addr + def send(self, data): + if self.closed: + raise EOFError + return self.low_sock.send(data, self.client_addr) + def sendall(self, data): + if self.closed: + raise EOFError + self.low_sock.sendall(data, self.client_addr) + def recv(self, size): + if self.closed: + raise EOFError + return self.low_sock.recv(size, self.client_addr) + @property + def closed(self): + return self.low_sock.own_socket._closed or self.client_addr not in self.low_sock.connections + def close(self): + if self.client_addr in self.low_sock.connections: + self.low_sock.close(self.client_addr) + def shutdown(self, *a, **kw): + self.close() + def poll(self, timeout): + return self.low_sock.poll(timeout, self.client_addr) + def packets_arrived(self, packet_type): + return self.low_sock.packets_arrived(packet_type, self.client_addr) + def fileno(self): + if self.closed: + raise EOFError + return self.low_sock.fileno(self.client_addr) + +class EUDP(object): + host = None + port = None + client = False + def __init__(self,encrypted=False, **kw): + self.encrypted = encrypted + self.incoming_packet_event = threading.Event() + self.new_conn_event = threading.Event() + self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.own_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.settimeout() + self.connection_lock = threading.Lock() + self.queue_lock = threading.Lock() + self.channel = None + self.connections = {} + self.connection_queue = [] + self.syn_received = {} + self.fileno_seq = 40000000 + def next_fileno(self): + self.fileno_seq += 1 + return self.fileno_seq + def packets_arrived(self, packet_type, connection=None): + try: + conn = self.connections[connection] + except: + raise EOFError + with conn.recv_lock: + return bool(len(conn.packet_buffer[packet_type])) + def poll(self, timeout, connection=None): + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + if not self.closed: + has_data = self.packets_arrived('DATA', connection) + if has_data: + return True + else: + if not timeout: + timeout = 0.5 + while True and not self.closed: + self.incoming_packet_event.wait(timeout) + has_data = self.packets_arrived('DATA', connection) + if not has_data: + continue + else: + return has_data + else: + self.incoming_packet_event.wait(timeout) + return self.packets_arrived('DATA', connection) + return False + def get_free_port(self): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.bind(('', 0)) + port = s.getsockname()[1] + s.close() + return port + def setsockopt(self, *a, **kw): + pass + def settimeout(self, timeout=5): + self.own_socket.settimeout(timeout) + def setblocking(self, mode): + self.own_socket.setblocking(mode) + def __repr__(self): + return 'EUDP()' + + def __str__(self): + return 'Connections: %s' \ + % str(self.connections) + def getsockname(self): + return self.own_socket.getsockname() + def getpeername(self): + if len(self.connections): + return list(self.connections.keys())[0] + else: + raise EOFError('Not connected') + def bind(self, addr): + self.host = addr[0] + self.port = addr[1] + self.own_socket.bind(addr) + def send(self, data, connection=None): + if self.closed: + raise EOFError + try: + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + conn = self.connections[connection] + data_parts = EUDP.data_divider(data) + for data_part in data_parts: + data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part) + packet = Packet() + packet.set_data(data_chunk) + self.__send_packet(connection, packet, retransmit=True) + return len(data) + except socket.error as error: + raise EOFError('Socket was closed before executing command. Error is: %s.' % error) + def sendall(self, data, connection=None): + _ = self.send(data, connection) + def recv(self, size, connection=None): + if self.closed: + raise EOFError + try: + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + data = self.find_correct_packet('DATA', connection, size) + if not self.status: + raise EOFError('Disconnecting') + return data + except socket.error as error: + raise EOFError('Socket was closed before executing command. Error is: %s.' % error) + def __send_packet(self, peer_addr, packet, retransmit=False, wait_cond='ACK'): + conn = self.connections[peer_addr] + packet.seq = conn.seq_inc() + packet_to_send = pickle.dumps(packet) + if not retransmit: + self.own_socket.sendto(packet_to_send, peer_addr) + else: + data_not_received = True + retransmit_count = 0 + while data_not_received and retransmit_count < 3: + data_not_received = False + try: + self.own_socket.sendto(packet_to_send, peer_addr) + answer = self.find_correct_packet(wait_cond, peer_addr, want_id=packet.id) + if not answer: + data_not_received = True + retransmit_count += 1 + except socket.timeout: + data_not_received = True + if not answer: + self.drop_connection(peer_addr) + raise EOFError('Connection lost') + return answer + def listen_handler(self, max_connections): + try: + while True and self.status: + try: + answer, address = self.find_correct_packet('SYN') + with self.queue_lock: + if len(self.connection_queue) < max_connections: + conn = Connection(address, self.encrypted) + conn.fileno = self.next_fileno() + if self.encrypted: + try: + conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey) + except: + raise socket.error('Init peer public key error') + self.connection_queue.append((answer, conn)) + self.blink_new_conn_event() + else: + if answer.id in map(lambda x: x[0].id, self.connection_queue): + continue + self.own_socket.sendto(b'Connections full', address) + except KeyError: + continue + except TypeError: + continue + except socket.error as error: + raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error)) + + def listen(self, max_connections=1): + self.status = 1 + self.central_receive() + try: + t = threading.Thread(target=self.listen_handler, args=(max_connections,)) + t.daemon = True + t.start() + except Exception as error: + raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error)) + def stop(self): + self.own_socket.close() + self.status = 0 + def shutdown(self, *a, **kw): + self.close() + self.status = 0 + self.connections = {} + self.stop() + def accept(self): + while self.status: + try: + self.new_conn_event.wait(0.1) + if self.connection_queue: + with self.queue_lock: + answer, conn = self.connection_queue.pop() + conn.recv_seq = answer.seq + self.connections[conn.peer_addr] = conn + syn_ack = SynAck(answer.id) + if self.encrypted: + syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey)) + answer = self.__send_packet(conn.peer_addr, syn_ack, retransmit=True) + return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr + except EOFError: + if conn.peer_addr in self.connections: + self.close(conn.peer_addr) + continue + except Exception as error: + if conn.peer_addr in self.connections: + self.close(conn.peer_addr) + raise EOFError('Something went wrong in accept func: ' + str(error)) + + def connect(self, server_address=('127.0.0.1', 10000)): + if server_address in self.connections: + raise EOFError('Already connected to peer') + try: + self.bind(('', self.get_free_port())) + self.status = 1 + self.client = True + self.central_receive() + conn = Connection(server_address, self.encrypted) + self.connections[server_address] = conn + syn = Syn() + if self.encrypted: + syn.set_pub(conn.pubkey) + try: + answer = self.__send_packet(server_address, syn, retransmit=True, wait_cond='SYN-ACK') + except EOFError: + raise EOFError('Remote peer unreachable') + if type(answer) == str: # == 'Connections full': + raise socket.error('Server cant receive any connections right now.') + if self.encrypted: + try: + peer_pub = conn.my_key.decrypt_raw(answer.pubkey) + conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub) + except: + raise socket.error('Decrypt peer public key error') + ack = Ack(answer.id) + self.__send_packet(server_address, ack) + self.channel = EUDPChannel(self) + conn.fileno = self.next_fileno() + except socket.error as error: + self.own_socket.close() + self.connections = {} + self.status = 0 + raise EOFError('The socket was closed. Error:' + str(error)) + def fileno(self, connection=None): + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + return self.connections[connection].fileno + @property + def closed(self): + return not bool(len(self.connections)) + def drop_connection(self, connection): + with self.connection_lock: + if len(self.connections): + self.connections.pop(connection) + def close(self, connection=None): + try: + if connection not in list(self.connections.keys()): + if connection is None: + if len(self.connections): + connection = list(self.connections.keys())[0] + else: + return + else: + raise EOFError('Connection not in connected devices') + fin = Fin() + answer = self.__send_packet(connection, fin, retransmit=True) + answer = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id) + if not answer: + raise Exception('The receiver didn\'t send the fin packet') + else: + self.drop_connection(connection) + if not len(self.connections) and self.client: + self.stop() + except Exception as error: + raise EOFError('Something went wrong in the close func! Error is: %s.' % error) + + def disconnect(self, connection, fin_id): + try: + ack = Ack(fin_id) + self.__send_packet(connection, ack) + fin_ack = FinAck(fin_id) + try: + self.__send_packet(connection, fin_ack) + except: + pass + self.drop_connection(connection) + except Exception as error: + raise EOFError('Something went wrong in disconnect func:%s ' % error) + + @staticmethod + def data_divider(data): + '''Divides the data into a list where each element's length is 1024''' + data = [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)] + return data + + @staticmethod + def checksum(source_bytes): + return hashlib.sha1(source_bytes).digest() + + def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH, want_id=None): + not_found = True + tries = 0 + while not_found and tries < 2 and self.status: + try: + not_found = False + if address[0] == 'Any': + order = self.syn_received.popitem() # to reverse the tuple received + return order[1], order[0] + try: + conn = self.connections[address] + except: + break + if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: + tries += 1 + if condition == 'DATA': + if self.poll(0.1, address): + data = b'' + while size: + if not self.poll(0.1, address): + continue + with conn.recv_lock: + packet = conn.packet_buffer[condition][0] + chunk = packet.data[:size] + chunk_len = len(chunk) + data += chunk + packet.data = packet.data[size:] + size -= chunk_len + if not len(packet.data): + try: + conn.packet_buffer[condition].pop(0) + except IndexError: + size = 0 + return data + else: + raise KeyError + else: + if self.packets_arrived(condition, address): + packet = conn.packet_buffer[condition].pop() + else: + raise KeyError + if want_id and packet.id != want_id: + raise KeyError + return packet + except KeyError: + not_found = True + self.incoming_packet_event.wait(0.5) + def blink_incoming_packet_event(self): + self.incoming_packet_event.set() + self.incoming_packet_event.clear() + def blink_new_conn_event(self): + self.new_conn_event.set() + self.new_conn_event.clear() + def multiplex(self, packet, address): + if address not in self.connections and not isinstance(packet, Syn): + return + if not isinstance(packet, Syn): + conn = self.connections[address] + if isinstance(packet, SynAck): + conn.recv_seq = packet.seq + elif conn.recv_seq == packet.seq: # Repeat ACK + ack = Ack(packet.id) + self.__send_packet(address, ack) + return + elif packet.seq < conn.recv_seq: # Possibly DUP + return + elif packet.seq > (conn.recv_seq + 1): # Intermediate packet lost + return + else: + conn.recv_seq = packet.seq + if isinstance(packet, Packet): + if packet.checksum == EUDP.checksum(packet.data): + if self.encrypted: + packet.data = conn.my_key.decrypt_raw(packet.data) + else: + return + if isinstance(packet, Fin): + self.disconnect(address, packet.id) + elif isinstance(packet, Syn): + if address in self.connections: + return + if packet.id not in map(lambda x: x.id, self.syn_received.values()): + self.syn_received[address] = packet + else: + with conn.recv_lock: + if packet.id not in map(lambda x: x.id, conn.packet_buffer[packet.type]): + bisect.insort(conn.packet_buffer[packet.type], packet) + if isinstance(packet, Packet): + ack = Ack(packet.id) + self.__send_packet(address, ack) + self.blink_incoming_packet_event() + + def central_receive_handler(self): + while True and self.status: + try: + packet, address = self.own_socket.recvfrom(SENT_SIZE) + packet = restricted_pickle_loads(packet) + self.multiplex(packet, address) + except pickle.UnpicklingError: + continue + except socket.timeout: + continue + except socket.error: + self.own_socket.close() + self.status = 0 + + def central_receive(self): + t = threading.Thread(target=self.central_receive_handler) + t.daemon = True + t.start() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..fd03d6d --- /dev/null +++ b/setup.py @@ -0,0 +1,5 @@ +from distutils.core import setup +setup(name='eudp', + version='1.0', + py_modules=['eudp'], + )