From f7a4da82bd9e09fb9b50cd81aa193548bfe0bb68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD=20=D0=91=D0=BE=D1=80=D0=BE?= =?UTF-8?q?=D0=B4=D0=B8=D0=BD?= Date: Mon, 8 Apr 2019 08:10:23 +0300 Subject: [PATCH] encrypted "tcp" over udp --- mods/__init__.py | 0 mods/rpyc_dtls.py | 153 +++++++++++++++++++++++++++++++++----- mods/rpyc_utcp.py | 74 +++++++++++++++++++ mods/utcp.py | 185 ++++++++++++++++++++++++++-------------------- 4 files changed, 315 insertions(+), 97 deletions(-) create mode 100644 mods/__init__.py create mode 100644 mods/rpyc_utcp.py diff --git a/mods/__init__.py b/mods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mods/rpyc_dtls.py b/mods/rpyc_dtls.py index 224277d..7151af5 100644 --- a/mods/rpyc_dtls.py +++ b/mods/rpyc_dtls.py @@ -1,6 +1,12 @@ +# Based on rpyc (4.0.2) + +# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!! + import rpyc from rpyc.utils.server import ThreadedServer, spawn -from rpyc.core.stream import SocketStream +from rpyc.core import SocketStream, Channel +from rpyc.core.stream import retry_errnos +from rpyc.utils.factory import connect_channel import socket from mbedtls import tls import datetime as dt @@ -10,6 +16,10 @@ from mbedtls import x509 from uuid import uuid4 from contextlib import suppress import sys +from rpyc.lib import safe_import +zlib = safe_import("zlib") +import errno +from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL def block(callback, *args, **kwargs): while True: @@ -33,6 +43,7 @@ class DTLSCerts: self.ca1_crt = self.ca0_crt.sign( ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456, basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3)) + self.srv_crt, self.srv_key = self.server_cert() def server_cert(self): now = dt.datetime.utcnow() ee0_key = pk.ECC() @@ -46,17 +57,28 @@ dtls_certs = DTLSCerts() trust_store = tls.TrustStore() trust_store.add(dtls_certs.ca0_crt) +srv_ctx_conf = tls.DTLSConfiguration( + trust_store=trust_store, + certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key), + validate_certificates=False, + ) +cli_ctx_conf = tls.DTLSConfiguration( + trust_store=trust_store, + validate_certificates=False, + ) +MAX_IO_CHUNK = 20971520 + class DTLSSocketStream(SocketStream): + MAX_IO_CHUNK = MAX_IO_CHUNK @classmethod def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs): if kwargs.pop('ipv6', False): family = socket.AF_INET6 else: family = socket.AF_INET - dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration( - trust_store=trust_store, - validate_certificates=False, - )) + #tls._enable_debug_output(cli_ctx_conf) + #tls._set_debug_level(10) + dtls_cli_ctx = tls.ClientContext(cli_ctx_conf) dtls_cli = dtls_cli_ctx.wrap_socket( socket.socket(family, socket.SOCK_DGRAM), server_hostname=None, @@ -66,33 +88,128 @@ class DTLSSocketStream(SocketStream): block(dtls_cli.do_handshake) return cls(dtls_cli) def read(self, count): - return block(SocketStream.read(self, count)) + while True: + try: + buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count)) + except socket.timeout: + continue + except socket.error: + ex = sys.exc_info()[1] + if get_exc_errno(ex) in retry_errnos: + # windows just has to be a bitch + # inpos: I agree + continue + self.close() + raise EOFError(ex) + else: + break + if not buf: + self.close() + raise EOFError("connection closed by peer") + return buf def write(self, data): - block(SocketStream.write(self, data)) + try: + _ = block(self.sock.send, data[:self.MAX_IO_CHUNK]) + except socket.error: + ex = sys.exc_info()[1] + self.close() + raise EOFError(ex) + +class DTLSChannel(Channel): + MAX_IO_CHUNK = MAX_IO_CHUNK + def recv(self): + raw_data = self.stream.read(self.MAX_IO_CHUNK) + header = raw_data[:self.FRAME_HEADER.size] + raw_data = raw_data[self.FRAME_HEADER.size:] + length, compressed = self.FRAME_HEADER.unpack(header) + data = raw_data[:length] + if compressed: + data = zlib.decompress(data) + return data + +def connect_stream(stream, service=rpyc.VoidService, config={}): + return connect_channel(DTLSChannel(stream), service=service, config=config) + +def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None, + cert_reqs=None, ssl_version=None, ciphers=None, + service=rpyc.VoidService, config={}, ipv6=False, keepalive=False): + ssl_kwargs = {'server_side' : False} + if ciphers is not None: + ssl_kwargs['ciphers'] = ciphers + s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive) + return connect_stream(s, service, config) class DTLSThreadedServer(ThreadedServer): - def dtls(self, listener_timeout = 0.5, reuse_addr = True): + def __init__(self, service, hostname = "", ipv6 = False, port = 0, + backlog = 10, reuse_addr = True, authenticator = None, registrar = None, + auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5, + socket_path = None): + ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port, + backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar, + auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout, + socket_path=socket_path) self.listener.close() - srv_crt, srv_key = dtls_certs.server_cert() - dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration( - trust_store=trust_store, - certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key), - validate_certificates=False, - )) + + self.host = hostname + self.port = port + #tls._enable_debug_output(srv_ctx_conf) + #tls._set_debug_level(10) + dtls_srv_ctx = tls.ServerContext(srv_ctx_conf) dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) if reuse_addr and sys.platform != 'win32': dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - dtls_srv.bind((self.host, self.port)) - dtls_srv.settimeout(listener_timeout) + dtls_srv.bind((hostname, port)) + # dtls_srv.settimeout(listener_timeout) self.listener = dtls_srv + sockname = self.listener.getsockname() + self.host, self.port = sockname[0], sockname[1] + def _serve_client(self, sock, credentials): + addrinfo = sock.getpeername() + if credentials: + self.logger.info("welcome %s (%r)", addrinfo, credentials) + else: + self.logger.info("welcome %s", addrinfo) + try: + config = dict(self.protocol_config, credentials = credentials, + endpoints = (sock.getsockname(), addrinfo), logger = self.logger) + conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config) + self._handle_connection(conn) + finally: + self.logger.info("goodbye %s", addrinfo) def _listen(self): if self.active: return + #self.listener.listen(self.backlog) ##################### if not self.port: self.port = self.listener.getsockname()[1] self.logger.info('server started on [%s]:%s', self.host, self.port) self.active = True - def _authenticate_and_serve_client(self, sock): + def accept(self): + while self.active: + try: + print('Accepting new connections!') + sock, addrinfo = self.listener.accept() + except socket.timeout: + pass + except socket.error: + ex = sys.exc_info()[1] + if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN): + pass + else: + raise EOFError() + raise + else: + break + + if not self.active: + return + + sock.setblocking(True) + self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno()) + print("accepted %s with fd %s" % (addrinfo, sock.fileno())) + self.clients.add(sock) + self._accept_method(sock) + def _accept_method(self, sock): addr = sock.getpeername() sock.setcookieparam(addr[0].encode()) with suppress(tls.HelloVerifyRequest): @@ -100,6 +217,6 @@ class DTLSThreadedServer(ThreadedServer): sock, addr = sock.accept() sock.setcookieparam(addr[0].encode()) block(sock.do_handshake) - ThreadedServer._authenticate_and_serve_client(self, sock) + spawn(self._authenticate_and_serve_client, sock) diff --git a/mods/rpyc_utcp.py b/mods/rpyc_utcp.py new file mode 100644 index 0000000..e68f130 --- /dev/null +++ b/mods/rpyc_utcp.py @@ -0,0 +1,74 @@ +import rpyc +from rpyc.utils.server import ThreadedServer +from rpyc.core import SocketStream, Channel +from rpyc.core.stream import retry_errnos +from rpyc.utils.factory import connect_channel +from rpyc.lib import Timeout +from rpyc.lib.compat import select_error +from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL +from . import utcp +import sys +import errno +import socket + +class UTCPSocketStream(SocketStream): + MAX_IO_CHUNK = utcp.DATA_LENGTH + @classmethod + def utcp_connect(cls, host, port, *a, **kw): + sock = utcp.TCP(encrypted=True) + sock.connect((host, port)) + return cls(sock) + def poll(self, timeout): + timeout = Timeout(timeout) + try: + while True: + try: + rl = self.sock.poll(timeout.timeleft()) + except select_error: + ex = sys.exc_info()[1] + if ex.args[0] == errno.EINTR: + continue + else: + raise + else: + break + except ValueError: + ex = sys.exc_info()[1] + raise select_error(str(ex)) + return rl +def connect_stream(stream, service=rpyc.VoidService, config={}): + return connect_channel(Channel(stream), service=service, config=config) +def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw): + s = UTCPSocketStream.utcp_connect(host, port, **kw) + return connect_stream(s, service, config) +class UTCPThreadedServer(ThreadedServer): + def __init__(self, service, hostname = '', ipv6 = False, port = 0, + backlog = 1, reuse_addr = True, authenticator = None, registrar = None, + auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5, + socket_path = None): + backlog = 1 + ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port, + backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar, + auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout, + socket_path=socket_path) + self.listener.close() + self.listener = None + ########## + self.listener = utcp.TCP(encrypted=True) + self.listener.bind((hostname, port)) + sockname = self.listener.getsockname() + self.host, self.port = sockname[0], sockname[1] + def _serve_client(self, sock, credentials): + addrinfo = sock.getpeername() + if credentials: + self.logger.info("welcome %s (%r)", addrinfo, credentials) + else: + self.logger.info("welcome %s", addrinfo) + try: + config = dict(self.protocol_config, credentials = credentials, + endpoints = (sock.getsockname(), addrinfo), logger = self.logger) + conn = self.service._connect(Channel(UTCPSocketStream(sock)), config) + self._handle_connection(conn) + finally: + self.logger.info("goodbye %s", addrinfo) + \ No newline at end of file diff --git a/mods/utcp.py b/mods/utcp.py index a7dbd5a..68dd347 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -6,8 +6,6 @@ import threading import io import hashlib import simplecrypto -from datetime import datetime - DATA_DIVIDE_LENGTH = 8000 PACKET_HEADER_SIZE = 512 # Pickle service info @@ -109,7 +107,7 @@ class TCPPacket(object): class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): # Only allow TCPPacket - if module == "utcp" and name == 'TCPPacket': + if name == 'TCPPacket': return TCPPacket # Forbid everything else. raise pickle.UnpicklingError("global '%s.%s' is forbidden" % @@ -123,44 +121,67 @@ class ConnectedSOCK(object): self.client_addr = client_addr self.low_sock = low_sock def __getattribute__(self, att): - if not att.startswith('_') and not att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']: - if att in self.low_sock.__dict__: - return getattr(self.low_sock, att) - return object.__getattribute__(self, att) + try: + return object.__getattribute__(self, att) + except AttributeError: + return getattr(self.low_sock, att) + def getpeername(self): + return self.client_addr def send(self, data): if self.closed: raise EOFError - self.low_sock.send(data, self.client_addr) - def recv(self, size=None): + return self.low_sock.send(data, self.client_addr) + def recv(self, size): if self.closed: raise EOFError - if size: - return self.low_sock.recv(self.client_addr)[:size] - else: - return self.low_sock.recv(self.client_addr) + return self.low_sock.recv(size, self.client_addr) @property def closed(self): return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin) def close(self): - self.low_sock.close(self.client_addr) + if self.client_addr in self.low_sock.connections: + self.low_sock.close(self.client_addr) + def shutdown(self, *a, **kw): + self.close() + def poll(self, timeout): + if self.client_addr in self.packets_received['DATA or FIN']: + return True + else: + self.incoming_packet_event.wait(timeout) + return self.client_addr in self.packets_received['DATA or FIN'] + return False class TCP(object): host = None port = None client = False + peer_keypair = {} + connections = {} + connection_queue = [] + packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} def __init__(self, af_type=None, sock_type=None, encrypted=False): self.encrypted = encrypted self.incoming_packet_event = threading.Event() + self.new_conn_event = threading.Event() #seq will have the last packet send and ack will have the next packet waiting to receive self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication. self.settimeout() - self.peer_keypair = {} - self.connections = {} - self.connection_queue = [] + #self.peer_keypair = {} + #self.connections = {} + #self.connection_queue = [] self.connection_lock = threading.Lock() self.queue_lock = threading.Lock() # each condition will have a dictionary of an address and it's corresponding packet. - self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} + #self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} + def poll(self, timeout): + if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: + return True + else: + self.incoming_packet_event.wait(timeout) + if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: + return True + return False + def get_free_port(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.bind(('', 0)) @@ -171,6 +192,8 @@ class TCP(object): pass def settimeout(self, timeout=5): self.own_socket.settimeout(timeout) + def setblocking(self, mode): + self.own_socket.setblocking(mode) def __repr__(self): return 'TCP()' @@ -178,7 +201,12 @@ class TCP(object): return 'Connections: %s' \ % str(self.connections) def getsockname(self): - return (self.host, self.port) + return self.own_socket.getsockname() + def getpeername(self): + if len(self.connections): + return list(self.connections.keys())[0] + else: + raise EOFError('Not connected') def bind(self, addr): self.host = addr[0] self.port = addr[1] @@ -201,60 +229,48 @@ class TCP(object): packet_to_send = pickle.dumps(self.connections[connection]) self.connections[connection].checksum = 0 self.connections[connection].data = b'' - while data_not_received: + retransmit_count = 0 + while data_not_received and retransmit_count < 3: data_not_received = False try: self.own_socket.sendto(packet_to_send, connection) answer = self.find_correct_packet('ACK', connection) + if not answer: + data_not_received = True + retransmit_count += 1 except socket.timeout: #print('timeout') data_not_received = True + if not answer: + self.drop_connection(connection) + raise EOFError('Connection lost') self.connections[connection].seq += len(data_part) + return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) - def recv(self, connection=None): + def recv(self, size, connection=None): try: - data = b'' if connection not in list(self.connections.keys()): if connection is None: connection = list(self.connections.keys())[0] else: - return 'Connection not in connected devices' - - while True and self.status: - data_part = self.find_correct_packet('DATA or FIN', connection) - if not self.status: - # print('I am disconnectiong cause sock is dead') - raise EOFError('Disconnected') - if data_part.packet_type() == 'FIN': - self.disconnect(connection) - raise EOFError('Disconnected') - checksum_value = TCP.checksum(data_part.data) - - while checksum_value != data_part.checksum: - data_part = self.find_correct_packet('DATA or FIN', connection) - checksum_value = TCP.checksum(data_part.data) - - data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data) - if data_chunk != PACKET_END: - data += data_chunk - self.connections[connection].ack = data_part.seq + len(data_part.data) - self.connections[connection].seq += 1 # syn flag is 1 byte - self.connections[connection].set_flags(ack=True) - self.connections[connection].data = b'' - packet_to_send = pickle.dumps(self.connections[connection]) - self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack - self.connections[connection].set_flags() - - if data_chunk == PACKET_END: - break + raise EOFError('Connection not in connected devices') + data = self.find_correct_packet('DATA or FIN', connection, size) + if not self.status: + raise EOFError('Disconnecting') return data except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) - - # conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA'] + def send_ack(self, connection, ack): + self.connections[connection].ack = ack + self.connections[connection].seq += 1 + self.connections[connection].set_flags(ack=True) + self.connections[connection].data = b'' + packet_to_send = pickle.dumps(self.connections[connection]) + self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack + self.connections[connection].set_flags() def listen_handler(self, max_connections): try: while True and self.status: @@ -265,12 +281,13 @@ class TCP(object): if self.encrypted: try: peer_pub = answer.data - self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key + self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec) self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub) except: self.peer_keypair.pop(address) raise socket.error('Init peer public key error') self.connection_queue.append((answer, address)) + self.blink_new_conn_event() else: self.own_socket.sendto('Connections full', address) except KeyError: @@ -291,6 +308,7 @@ class TCP(object): def accept(self): try: while True: + self.new_conn_event.wait(0.1) if self.connection_queue: with self.queue_lock: answer, address = self.connection_queue.pop() @@ -367,14 +385,17 @@ class TCP(object): self.peer_keypair = {} self.status = 0 raise EOFError('The socket was closed. Error:' + str(error)) - def shutdown(self, *a, **kw): - self.own_socket.close() - self.status = 0 def fileno(self): return self.own_socket.fileno() @property def closed(self): return self.own_socket._closed + def drop_connection(self, connection): + with self.connection_lock: + if len(self.connections): + self.connections.pop(connection) + if len(self.peer_keypair): + self.peer_keypair.pop(connection) def close(self, connection=None): try: if connection not in list(self.connections.keys()): @@ -392,19 +413,11 @@ class TCP(object): if answer.flag_fin != 1: raise Exception('The receiver didn\'t send the fin packet') else: - self.connections[connection].ack += 1 - self.connections[connection].seq += 1 - self.connections[connection].set_flags(ack=True) - packet_to_send = pickle.dumps(self.connections[connection]) - self.own_socket.sendto(packet_to_send, connection) - with self.connection_lock: - if len(self.connections): - self.connections.pop(connection) - if len(self.peer_keypair): - self.peer_keypair.pop(connection) - #if len(self.connections) == 0 and self.client: - # self.own_socket.close() - # self.status = 0 + self.send_ack(connection, self.connections[connection].ack + 1) + self.drop_connection(connection) + if len(self.connections) == 0 and self.client: + self.own_socket.close() + self.status = 0 except Exception as error: raise EOFError('Something went wrong in the close func! Error is: %s.' % error) @@ -436,7 +449,7 @@ class TCP(object): def checksum(source_bytes): return hashlib.sha1(source_bytes).digest() - def find_correct_packet(self, condition, address=('Any',)): + def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH): not_found = True tries = 0 while not_found and tries < 2 and self.status: @@ -448,7 +461,9 @@ class TCP(object): if condition == 'ACK': tries += 1 if condition == 'DATA or FIN': - packet = self.packets_received[condition][address].pop() + with self.connection_lock: + packet = self.packets_received[condition][address][:size] + self.packets_received[condition][address] = self.packets_received[condition][address][size:] if not len(self.packets_received[condition][address]): del self.packets_received[condition][address] else: @@ -457,21 +472,33 @@ class TCP(object): except KeyError: not_found = True self.incoming_packet_event.wait(0.1) - def blink_event(self): + def blink_incoming_packet_event(self): self.incoming_packet_event.set() self.incoming_packet_event.clear() + def blink_new_conn_event(self): + self.new_conn_event.set() + self.new_conn_event.clear() def sort_answers(self, packet, address): - if packet.packet_type() == 'DATA' or packet.packet_type() == 'FIN': - if address not in self.packets_received['DATA or FIN']: - self.packets_received['DATA or FIN'][address] = [] - self.packets_received['DATA or FIN'][address].insert(0, packet) - self.blink_event() + if address not in self.connections and packet.packet_type() != 'SYN': + return + if packet.packet_type() == 'FIN': + self.disconnect(address) + elif packet.packet_type() == 'DATA': + if packet.checksum == TCP.checksum(packet.data): + data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data) + if data_chunk != PACKET_END: + with self.connection_lock: + if address not in self.packets_received['DATA or FIN']: + self.packets_received['DATA or FIN'][address] = b'' + self.packets_received['DATA or FIN'][address] += data_chunk + self.send_ack(address, packet.seq + len(packet.data)) + self.blink_incoming_packet_event() elif packet.packet_type() == '': #print('redundant packet found', packet) pass else: self.packets_received[packet.packet_type()][address] = packet - self.blink_event() + self.blink_incoming_packet_event() def central_receive_handler(self): while True and self.status: