# Based on rpyc (4.0.2)
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
"""RPyC transport over DTLS (python-mbedtls) on UDP sockets.

Provides:

* an ephemeral certificate hierarchy and DTLS client/server configurations,
* ``DTLSSocketStream`` / ``DTLSChannel``: a stream and framing layer where a
  whole rpyc frame is taken from a single socket read,
* ``dtls_connect`` / ``connect_stream``: client-side helpers,
* ``DTLSThreadedServer``: an rpyc ``ThreadedServer`` listening on UDP/DTLS.
"""
import datetime as dt
import errno
import socket
import sys
from contextlib import suppress
from uuid import uuid4

import rpyc
from mbedtls import hash as hashlib
from mbedtls import pk
from mbedtls import tls
from mbedtls import x509
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.lib import safe_import
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
from rpyc.utils.factory import connect_channel
from rpyc.utils.server import ThreadedServer, spawn

zlib = safe_import("zlib")


def block(callback, *args, **kwargs):
    """Call ``callback`` until it completes without asking for more I/O.

    ``tls.WantReadError`` / ``tls.WantWriteError`` are swallowed and the call
    is retried immediately (busy loop); any other exception propagates.

    Returns:
        The return value of the first call that does not raise one of those
        two exceptions.
    """
    while True:
        with suppress(tls.WantReadError, tls.WantWriteError):
            return callback(*args, **kwargs)


class DTLSCerts:
    """Certificate hierarchy generated in memory when instantiated.

    * ``ca0``: self-signed RSA root CA.
    * ``ca1``: ECC intermediate CA signed by ``ca0``.
    * ``srv``: ECC end-entity certificate signed by ``ca1`` (``server_cert``).

    Every certificate is valid from "now" (naive UTC) for 3650 days.
    """

    def __init__(self):
        now = dt.datetime.utcnow()

        # Root CA: self-signed; BasicConstraints(ca=True, max_path_length=1).
        self.ca0_key = pk.RSA()
        self.ca0_key.generate()
        ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
        self.ca0_crt = x509.CRT.selfsign(
            ca0_csr, self.ca0_key,
            not_before=now, not_after=now + dt.timedelta(days=3650),
            serial_number=0x123456,
            basic_constraints=x509.BasicConstraints(True, 1))

        # Intermediate CA signed by the root.
        # NOTE(review): reuses serial 0x123456, the same serial the root CA
        # gave its own certificate; RFC 5280 expects serials to be unique per
        # issuer. Harmless while validate_certificates=False, but confirm
        # before enabling validation.
        self.ca1_key = pk.ECC()
        self.ca1_key.generate()
        ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
        self.ca1_crt = self.ca0_crt.sign(
            ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
            basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))

        self.srv_crt, self.srv_key = self.server_cert()

    def server_cert(self):
        """Create a new end-entity certificate signed by the intermediate CA.

        The subject CN embeds a random UUID so each call yields a distinct
        subject.

        Returns:
            ``(certificate, private_key)``.
        """
        now = dt.datetime.utcnow()
        ee0_key = pk.ECC()
        ee0_key.generate()
        ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
        ee0_crt = self.ca1_crt.sign(
            ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654)
        return ee0_crt, ee0_key


# Generated at import time; every process gets a fresh hierarchy.
dtls_certs = DTLSCerts()

trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)

# SECURITY NOTE: validate_certificates=False on both sides means peer
# certificates are not verified, so peers are not authenticated.
srv_ctx_conf = tls.DTLSConfiguration(
    trust_store=trust_store,
    certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
    validate_certificates=False,
)
cli_ctx_conf = tls.DTLSConfiguration(
    trust_store=trust_store,
    validate_certificates=False,
)

# Upper bound (20 MiB) on the size requested from / passed to a single
# socket call.
MAX_IO_CHUNK = 20971520


class DTLSSocketStream(SocketStream):
    """rpyc ``SocketStream`` over a DTLS-wrapped UDP socket.

    ``read`` returns the result of one successful ``recv`` and ``write``
    issues one ``send``; neither loops to accumulate a byte count, so the
    framing layer (``DTLSChannel``) must get a whole frame per call.
    """

    MAX_IO_CHUNK = MAX_IO_CHUNK

    @classmethod
    def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
        """Open a DTLS client connection to ``(host, port)`` and wrap it.

        Args:
            host: Server host name or address.
            port: Server UDP port.
            ssl_kwargs: Accepted for API compatibility; not used.
            timeout: Socket timeout in seconds.
            **kwargs: ``ipv6=True`` selects ``AF_INET6``; other keys
                (e.g. ``keepalive``) are ignored.

        Returns:
            A new ``DTLSSocketStream`` around the handshaken socket.

        Raises:
            Any error raised by ``connect`` or the DTLS handshake. The socket
            is closed before the error is re-raised.
        """
        family = socket.AF_INET6 if kwargs.pop('ipv6', False) else socket.AF_INET
        # For mbedtls debugging:
        #   tls._enable_debug_output(cli_ctx_conf); tls._set_debug_level(10)
        dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
        dtls_cli = dtls_cli_ctx.wrap_socket(
            socket.socket(family, socket.SOCK_DGRAM),
            server_hostname=None,
        )
        try:
            dtls_cli.settimeout(timeout)
            dtls_cli.connect((host, port))
            block(dtls_cli.do_handshake)
        except BaseException:
            # Don't leak the UDP socket if the peer is unreachable or the
            # handshake fails.
            dtls_cli.close()
            raise
        return cls(dtls_cli)

    def read(self, count):
        """Return the data from one ``recv`` of up to ``min(MAX_IO_CHUNK, count)`` bytes.

        Socket timeouts and errnos listed in rpyc's ``retry_errnos`` are
        retried indefinitely.

        Raises:
            EOFError: on any other socket error or an empty read; the stream
                is closed first.
        """
        while True:
            try:
                buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
            except socket.timeout:
                continue
            except socket.error as ex:
                if get_exc_errno(ex) in retry_errnos:
                    # Transient error (notably seen on Windows): try again.
                    continue
                self.close()
                raise EOFError(ex) from ex
            break
        if not buf:
            self.close()
            raise EOFError("connection closed by peer")
        return buf

    def write(self, data):
        """Send ``data`` with a single ``send`` call.

        NOTE: only the first ``MAX_IO_CHUNK`` bytes are passed to ``send``
        and a short send is not retried, because a frame must not be split
        across datagrams.

        Raises:
            EOFError: on a socket error; the stream is closed first.
        """
        try:
            block(self.sock.send, data[:self.MAX_IO_CHUNK])
        except socket.error as ex:
            self.close()
            raise EOFError(ex) from ex


class DTLSChannel(Channel):
    """rpyc ``Channel`` that parses a whole frame from a single stream read.

    Header, payload and anything after the payload (e.g. the frame flusher)
    are expected to arrive in one ``stream.read`` call. Bytes after the
    announced payload length are ignored.
    """

    MAX_IO_CHUNK = MAX_IO_CHUNK

    def recv(self):
        """Read one frame and return its payload, decompressed if flagged.

        Raises:
            EOFError: if the read is shorter than the frame header, or shorter
                than the payload length announced by the header.
        """
        raw_data = self.stream.read(self.MAX_IO_CHUNK)
        header_size = self.FRAME_HEADER.size
        if len(raw_data) < header_size:
            raise EOFError("truncated frame header: got %d of %d bytes"
                           % (len(raw_data), header_size))
        length, compressed = self.FRAME_HEADER.unpack(raw_data[:header_size])
        data = raw_data[header_size:header_size + length]
        if len(data) < length:
            # A frame that does not fit in one read cannot be reassembled
            # here; fail loudly instead of decoding a partial payload.
            raise EOFError("truncated frame: expected %d bytes, got %d"
                           % (length, len(data)))
        if compressed:
            data = zlib.decompress(data)
        return data


def connect_stream(stream, service=rpyc.VoidService, config=None):
    """Create an rpyc connection over an existing ``DTLSSocketStream``.

    Args:
        stream: A connected ``DTLSSocketStream``.
        service: Local rpyc service exposed to the peer.
        config: rpyc protocol configuration; ``None`` means an empty dict.
    """
    if config is None:
        config = {}
    return connect_channel(DTLSChannel(stream), service=service, config=config)


def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
                 cert_reqs=None, ssl_version=None, ciphers=None,
                 service=rpyc.VoidService, config=None, ipv6=False,
                 keepalive=False):
    """Connect to a DTLS rpyc server and return the rpyc connection.

    The signature appears to be modeled on rpyc's ``ssl_connect``. However:

    * ``keyfile``, ``certfile``, ``ca_certs``, ``cert_reqs`` and
      ``ssl_version`` are ignored. The module-level ``cli_ctx_conf`` is used
      instead.
    * ``ciphers`` and ``keepalive`` are forwarded to
      ``DTLSSocketStream.dtls_connect``, which does not use them.

    Args:
        host: Server host name or address.
        port: Server UDP port.
        service: Local rpyc service exposed to the server.
        config: rpyc protocol configuration; ``None`` means an empty dict.
        ipv6: Use an ``AF_INET6`` socket.
    """
    ssl_kwargs = {'server_side': False}
    if ciphers is not None:
        ssl_kwargs['ciphers'] = ciphers
    stream = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs,
                                           ipv6=ipv6, keepalive=keepalive)
    return connect_stream(stream, service, config)


class DTLSThreadedServer(ThreadedServer):
    """rpyc ``ThreadedServer`` that serves clients over DTLS on UDP.

    ``ThreadedServer.__init__`` creates its usual listener, which is closed
    immediately. It is replaced by a DTLS-wrapped UDP socket bound to
    ``(hostname, port)``.
    """

    def __init__(self, service, hostname="", ipv6=False, port=0, backlog=10,
                 reuse_addr=True, authenticator=None, registrar=None,
                 auto_register=None, protocol_config=None, logger=None,
                 listener_timeout=0.5, socket_path=None):
        # A fresh dict per server: a shared ``{}`` default would be stored by
        # ThreadedServer and could be mutated/shared across instances.
        if protocol_config is None:
            protocol_config = {}
        ThreadedServer.__init__(
            self, service, hostname=hostname, ipv6=ipv6, port=port,
            backlog=backlog, reuse_addr=reuse_addr,
            authenticator=authenticator, registrar=registrar,
            auto_register=auto_register, protocol_config=protocol_config,
            logger=logger, listener_timeout=listener_timeout,
            socket_path=socket_path)
        # Discard the listener created by ThreadedServer; DTLS needs UDP.
        self.listener.close()

        # For mbedtls debugging:
        #   tls._enable_debug_output(srv_ctx_conf); tls._set_debug_level(10)
        # Honour ``ipv6`` the same way the client side does.
        family = socket.AF_INET6 if ipv6 else socket.AF_INET
        dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
        dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(family, socket.SOCK_DGRAM))
        if reuse_addr and sys.platform != 'win32':
            dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        dtls_srv.bind((hostname, port))
        # NOTE: listener_timeout is not applied to the DTLS socket, so
        # accept() blocks until a datagram arrives.
        self.listener = dtls_srv
        # Report the actual bound address (resolves port 0 to the real port).
        sockname = self.listener.getsockname()
        self.host, self.port = sockname[0], sockname[1]

    def _serve_client(self, sock, credentials):
        """Run rpyc for one client over a ``DTLSChannel``.

        Builds the per-connection config (protocol config plus credentials,
        endpoints and logger), creates the connection via the service and
        passes it to ``_handle_connection``. A goodbye is logged on exit.
        """
        addrinfo = sock.getpeername()
        if credentials:
            self.logger.info("welcome %s (%r)", addrinfo, credentials)
        else:
            self.logger.info("welcome %s", addrinfo)
        try:
            config = dict(self.protocol_config,
                          credentials=credentials,
                          endpoints=(sock.getsockname(), addrinfo),
                          logger=self.logger)
            conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
            self._handle_connection(conn)
        finally:
            self.logger.info("goodbye %s", addrinfo)

    def _listen(self):
        """Mark the server active.

        UDP sockets have no ``listen()`` step, so ``backlog`` is not used.
        """
        if self.active:
            return
        if not self.port:
            self.port = self.listener.getsockname()[1]
        self.logger.info('server started on [%s]:%s', self.host, self.port)
        self.active = True

    def accept(self):
        """Wait for a new peer on the listener and hand it to ``_accept_method``.

        While the server is active, timeouts, EINTR and EAGAIN are retried.
        Returns without doing anything if the server is no longer active.

        Raises:
            EOFError: on any other socket error from ``listener.accept()``.
        """
        while self.active:
            self.logger.debug("Accepting new connections!")
            try:
                sock, addrinfo = self.listener.accept()
            except socket.timeout:
                pass
            except socket.error as ex:
                if get_exc_errno(ex) not in (errno.EINTR, errno.EAGAIN):
                    raise EOFError() from ex
            else:
                break

        if not self.active:
            return

        sock.setblocking(True)
        self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
        self.clients.add(sock)
        self._accept_method(sock)

    def _accept_method(self, sock):
        """Run the DTLS cookie exchange and handshake, then serve on a new thread.

        The first handshake on ``sock`` is expected to stop with
        ``tls.HelloVerifyRequest`` (suppressed here). A second ``accept`` on
        ``sock`` then yields the socket on which the handshake completes, and
        that socket is passed to ``_authenticate_and_serve_client`` via
        ``spawn``.
        """
        addr = sock.getpeername()
        # The peer IP is the client-specific parameter for the DTLS cookie.
        sock.setcookieparam(addr[0].encode())
        with suppress(tls.HelloVerifyRequest):
            block(sock.do_handshake)
        # NOTE(review): only the first ``sock`` was added to self.clients;
        # the socket returned below is the one actually served. Confirm
        # whether the first one should be closed/removed here (it may relate
        # to the multiple-connections issue noted at the top of the file).
        conn, addr = sock.accept()
        conn.setcookieparam(addr[0].encode())
        block(conn.do_handshake)
        spawn(self._authenticate_and_serve_client, conn)