diff --git a/mods/rpyc_dtls.py b/mods/rpyc_dtls.py new file mode 100644 index 0000000..224277d --- /dev/null +++ b/mods/rpyc_dtls.py @@ -0,0 +1,105 @@ +import rpyc +from rpyc.utils.server import ThreadedServer, spawn +from rpyc.core.stream import SocketStream +import socket +from mbedtls import tls +import datetime as dt +from mbedtls import hash as hashlib +from mbedtls import pk +from mbedtls import x509 +from uuid import uuid4 +from contextlib import suppress +import sys + +def block(callback, *args, **kwargs): + while True: + with suppress(tls.WantReadError, tls.WantWriteError): + return callback(*args, **kwargs) + +class DTLSCerts: + def __init__(self): + now = dt.datetime.utcnow() + self.ca0_key = pk.RSA() + _ = self.ca0_key.generate() + ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256()) + self.ca0_crt = x509.CRT.selfsign( + ca0_csr, self.ca0_key, + not_before=now, not_after=now + dt.timedelta(days=3650), + serial_number=0x123456, + basic_constraints=x509.BasicConstraints(True, 1)) + self.ca1_key = pk.ECC() + _ = self.ca1_key.generate() + ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256()) + self.ca1_crt = self.ca0_crt.sign( + ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456, + basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3)) + def server_cert(self): + now = dt.datetime.utcnow() + ee0_key = pk.ECC() + _ = ee0_key.generate() + ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256()) + ee0_crt = self.ca1_crt.sign( + ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654) + return ee0_crt, ee0_key + +dtls_certs = DTLSCerts() +trust_store = tls.TrustStore() +trust_store.add(dtls_certs.ca0_crt) + +class DTLSSocketStream(SocketStream): + @classmethod + def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs): + if kwargs.pop('ipv6', False): + family = socket.AF_INET6 + else: + family = socket.AF_INET + dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration( + trust_store=trust_store, + validate_certificates=False, + )) + dtls_cli = dtls_cli_ctx.wrap_socket( + socket.socket(family, socket.SOCK_DGRAM), + server_hostname=None, + ) + dtls_cli.settimeout(timeout) + dtls_cli.connect((host, port)) + block(dtls_cli.do_handshake) + return cls(dtls_cli) + def read(self, count): + return block(SocketStream.read(self, count)) + def write(self, data): + block(SocketStream.write(self, data)) + +class DTLSThreadedServer(ThreadedServer): + def dtls(self, listener_timeout = 0.5, reuse_addr = True): + self.listener.close() + srv_crt, srv_key = dtls_certs.server_cert() + dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration( + trust_store=trust_store, + certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key), + validate_certificates=False, + )) + dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) + if reuse_addr and sys.platform != 'win32': + dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + dtls_srv.bind((self.host, self.port)) + dtls_srv.settimeout(listener_timeout) + self.listener = dtls_srv + def _listen(self): + if self.active: + return + if not self.port: + self.port = self.listener.getsockname()[1] + self.logger.info('server started on [%s]:%s', self.host, self.port) + self.active = True + def _authenticate_and_serve_client(self, sock): + addr = sock.getpeername() + sock.setcookieparam(addr[0].encode()) + with suppress(tls.HelloVerifyRequest): + block(sock.do_handshake) + sock, addr = sock.accept() + sock.setcookieparam(addr[0].encode()) + block(sock.do_handshake) + ThreadedServer._authenticate_and_serve_client(self, sock) + +