"""RPyC transport over DTLS (UDP + TLS) built on python-mbedtls.

Provides a client-side stream (:class:`DTLSSocketStream`) and a server
(:class:`DTLSThreadedServer`).  The certificates both sides use are generated
in-process at import time by :class:`DTLSCerts`.
"""
import rpyc
from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core.stream import SocketStream
import socket
from mbedtls import tls
import datetime as dt
from mbedtls import hash as hashlib
from mbedtls import pk
from mbedtls import x509
from uuid import uuid4
from contextlib import suppress
import sys


def block(callback, *args, **kwargs):
    """Call ``callback(*args, **kwargs)``, retrying while it raises Want* errors.

    ``tls.WantReadError`` / ``tls.WantWriteError`` are swallowed and the call
    is repeated (busy-wait) until it completes; its return value is returned.
    Any other exception propagates unchanged.
    """
    while True:
        with suppress(tls.WantReadError, tls.WantWriteError):
            return callback(*args, **kwargs)


class DTLSCerts:
    """In-memory two-level PKI shared by both DTLS endpoints.

    Attributes set by ``__init__``:

    * ``ca0_key`` / ``ca0_crt`` -- self-signed RSA root CA, valid 3650 days,
      ``BasicConstraints(True, 1)`` (CA, max path length 1);
    * ``ca1_key`` / ``ca1_crt`` -- ECC intermediate CA signed by the root,
      valid 3650 days, ``ca=True, max_path_length=3``.

    :meth:`server_cert` issues end-entity certificates from the intermediate.
    """

    def __init__(self):
        now = dt.datetime.utcnow()

        # Root CA (self-signed).
        self.ca0_key = pk.RSA()
        _ = self.ca0_key.generate()
        ca0_csr = x509.CSR.new(
            self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
        self.ca0_crt = x509.CRT.selfsign(
            ca0_csr, self.ca0_key,
            not_before=now, not_after=now + dt.timedelta(days=3650),
            serial_number=0x123456,
            basic_constraints=x509.BasicConstraints(True, 1))

        # Intermediate CA, signed by the root.
        # NOTE(review): this reuses the root's serial (0x123456) under the same
        # issuer; RFC 5280 expects serial numbers to be unique per issuer.
        self.ca1_key = pk.ECC()
        _ = self.ca1_key.generate()
        ca1_csr = x509.CSR.new(
            self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
        self.ca1_crt = self.ca0_crt.sign(
            ca1_csr, self.ca0_key,
            now, now + dt.timedelta(days=3650), 0x123456,
            basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))

    def server_cert(self):
        """Issue a new end-entity certificate signed by the intermediate CA.

        A fresh ECC key is generated on every call. The subject CN embeds a
        random UUID, so each issued certificate has a distinct subject.

        :returns: ``(certificate, private_key)``.
        """
        now = dt.datetime.utcnow()
        ee0_key = pk.ECC()
        _ = ee0_key.generate()
        ee0_csr = x509.CSR.new(
            ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
        # NOTE(review): every issued certificate gets the same serial (0x987654).
        ee0_crt = self.ca1_crt.sign(
            ee0_csr, self.ca1_key,
            now, now + dt.timedelta(days=3650), 0x987654)
        return ee0_crt, ee0_key


# Generated once at import time and shared by clients and servers.
dtls_certs = DTLSCerts()
# Trust store holding only the root CA.
# NOTE(review): both endpoints pass validate_certificates=False, so this trust
# store is currently not enforced.
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)


class DTLSSocketStream(SocketStream):
    """RPyC ``SocketStream`` carried over a python-mbedtls DTLS socket."""

    @classmethod
    def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
        """Open a DTLS client connection to ``(host, port)`` and wrap it.

        :param host: server address.
        :param port: server UDP port.
        :param ssl_kwargs: accepted for signature compatibility; unused.
        :param timeout: socket timeout in seconds.
        :param kwargs: only ``ipv6`` (bool) is honoured and selects AF_INET6;
            other keys are ignored.
        :returns: a handshaken ``DTLSSocketStream``.
        """
        family = socket.AF_INET6 if kwargs.pop('ipv6', False) else socket.AF_INET
        dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration(
            trust_store=trust_store,
            validate_certificates=False,
        ))
        dtls_cli = dtls_cli_ctx.wrap_socket(
            socket.socket(family, socket.SOCK_DGRAM),
            server_hostname=None,
        )
        try:
            dtls_cli.settimeout(timeout)
            dtls_cli.connect((host, port))
            block(dtls_cli.do_handshake)
        except BaseException:
            # Don't leak the socket on failure. Cleanup errors must not mask
            # the original exception.
            with suppress(Exception):
                dtls_cli.close()
            raise
        return cls(dtls_cli)

    def read(self, count):
        """Read ``count`` bytes via ``SocketStream.read``, retrying on Want* errors.

        NOTE(review): the whole ``SocketStream.read`` call is retried, so bytes
        it had already gathered before a Want* error would be discarded.
        Confirm this cannot happen with the timeout-mode socket used here.
        """
        return block(SocketStream.read, self, count)

    def write(self, data):
        """Send ``data`` via ``SocketStream.write``, retrying on Want* errors.

        The same whole-call retry caveat as :meth:`read` applies: a retry
        could resend data that was already sent.
        """
        block(SocketStream.write, self, data)


class DTLSThreadedServer(ThreadedServer):
    """RPyC ``ThreadedServer`` that serves DTLS (UDP) clients.

    Call :meth:`dtls` to replace the inherited ``self.listener`` with a DTLS
    server socket. Presumably this must happen before the server starts
    accepting clients; confirm against the callers.
    """

    def dtls(self, listener_timeout=0.5, reuse_addr=True):
        """Replace the inherited listener with a DTLS socket on (host, port).

        The new socket is fully set up before the old listener is closed. A
        failure therefore leaves the previous listener untouched.

        :param listener_timeout: timeout in seconds set on the new listener.
        :param reuse_addr: set SO_REUSEADDR (skipped on Windows).
        """
        srv_crt, srv_key = dtls_certs.server_cert()
        dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration(
            trust_store=trust_store,
            # Present leaf + intermediate so clients can chain to the root.
            certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key),
            validate_certificates=False,
        ))
        dtls_srv = dtls_srv_ctx.wrap_socket(
            socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
        try:
            if reuse_addr and sys.platform != 'win32':
                dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            dtls_srv.bind((self.host, self.port))
            dtls_srv.settimeout(listener_timeout)
        except BaseException:
            with suppress(Exception):
                dtls_srv.close()
            raise
        self.listener.close()
        self.listener = dtls_srv

    def _listen(self):
        """Mark the server active.

        There is no ``listen()`` call because the listener is a datagram
        socket.
        """
        if self.active:
            return
        if not self.port:
            # Report the port actually bound.
            self.port = self.listener.getsockname()[1]
        self.logger.info('server started on [%s]:%s', self.host, self.port)
        self.active = True

    def _authenticate_and_serve_client(self, sock):
        """Complete the DTLS cookie exchange and handshake, then serve the client.

        The first handshake on ``sock`` is expected to stop with
        ``tls.HelloVerifyRequest`` once the cookie (bound to the client IP) has
        been sent. ``sock.accept()`` then yields the connection on which the
        handshake completes.

        The pre-cookie socket is removed from ``self.clients`` and closed
        whether or not this succeeds. The final connection is closed if its
        handshake fails. Otherwise it is registered in ``self.clients`` before
        being handed to ``ThreadedServer._authenticate_and_serve_client``.
        """
        hello_sock = sock
        try:
            addr = hello_sock.getpeername()
            hello_sock.setcookieparam(addr[0].encode())
            with suppress(tls.HelloVerifyRequest):
                block(hello_sock.do_handshake)
            sock, addr = hello_sock.accept()
        finally:
            # NOTE(review): assumes the pre-cookie socket is not needed after
            # accept(); python-mbedtls' DTLS server example simply discards it.
            self.clients.discard(hello_sock)
            with suppress(Exception):
                hello_sock.close()
        try:
            sock.setcookieparam(addr[0].encode())
            block(sock.do_handshake)
        except BaseException:
            with suppress(Exception):
                sock.close()
            raise
        self.clients.add(sock)
        ThreadedServer._authenticate_and_serve_client(self, sock)