black-mamba/mods/rpyc_dtls.py

106 lines
4.0 KiB
Python
Raw Normal View History

import rpyc
from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core.stream import SocketStream
import socket
from mbedtls import tls
import datetime as dt
from mbedtls import hash as hashlib
from mbedtls import pk
from mbedtls import x509
from uuid import uuid4
from contextlib import suppress
import sys
def block(callback, *args, **kwargs):
    """Invoke ``callback(*args, **kwargs)`` until the TLS layer stops asking to retry.

    ``tls.WantReadError`` and ``tls.WantWriteError`` are treated as "try again":
    the call is repeated immediately, with no delay between attempts, for as
    long as either is raised. The return value of the first call that completes
    is returned; any other exception propagates to the caller.
    """
    retry_on = (tls.WantReadError, tls.WantWriteError)
    while True:
        try:
            outcome = callback(*args, **kwargs)
        except retry_on:
            continue
        return outcome
class DTLSCerts:
    """Ephemeral X.509 hierarchy used by the DTLS transport.

    Construction generates the following, each valid for 3650 days from the
    current UTC time:

    * ``ca0_key`` / ``ca0_crt`` -- a self-signed RSA root CA (path length 1);
    * ``ca1_key`` / ``ca1_crt`` -- an ECC intermediate CA signed by the root.

    :meth:`server_cert` issues a new end-entity certificate from the
    intermediate on every call. Every certificate receives its own random
    serial number (see :meth:`_new_serial`).
    """

    @staticmethod
    def _new_serial():
        """Return a random, positive certificate serial number.

        RFC 5280 (section 4.1.2.2) requires a serial number to be unique among
        the certificates issued by one CA, so a fixed constant must not be
        reused across certificates that share an issuer. The upper 64 bits of a
        UUID4 hold 60 random bits plus the non-zero version nibble. The result
        is therefore always positive and well inside the 20-octet limit.
        """
        return uuid4().int >> 64

    def __init__(self):
        now = dt.datetime.utcnow()
        lifetime = dt.timedelta(days=3650)

        # Root CA: RSA key, self-signed.
        self.ca0_key = pk.RSA()
        self.ca0_key.generate()
        ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
        self.ca0_crt = x509.CRT.selfsign(
            ca0_csr, self.ca0_key,
            not_before=now, not_after=now + lifetime,
            serial_number=self._new_serial(),
            basic_constraints=x509.BasicConstraints(True, 1))

        # Intermediate CA: ECC key, signed by the root.
        # NOTE(review): the root's path length of 1 presumably caps the
        # effective depth below this CA regardless of max_path_length=3;
        # confirm the value is intentional.
        self.ca1_key = pk.ECC()
        self.ca1_key.generate()
        ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
        self.ca1_crt = self.ca0_crt.sign(
            ca1_csr, self.ca0_key, now, now + lifetime, self._new_serial(),
            basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))

    def server_cert(self):
        """Issue a fresh end-entity certificate signed by the intermediate CA.

        A new ECC key pair is generated on every call. The subject CN embeds a
        random UUID and the serial number is random, so each call yields a
        distinct certificate.

        :returns: ``(certificate, private_key)`` tuple.
        """
        now = dt.datetime.utcnow()
        ee0_key = pk.ECC()
        ee0_key.generate()
        ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
        ee0_crt = self.ca1_crt.sign(
            ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), self._new_serial())
        return ee0_crt, ee0_key
# Process-wide certificate hierarchy, generated once when this module is
# imported (fresh keys and certificates in every importing process).
dtls_certs = DTLSCerts()
# Trust store containing only the self-signed root (ca0). It is passed to both
# the client and server DTLS configurations in this module.
# NOTE(review): those configurations set validate_certificates=False, so peer
# chains are presumably not checked against this store -- confirm.
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)
class DTLSSocketStream(SocketStream):
    """rpyc ``SocketStream`` carried over a python-mbedtls DTLS (UDP) socket.

    Reads and writes go through :func:`block`, so they are retried while the
    TLS layer raises ``WantReadError`` / ``WantWriteError``.
    """

    @classmethod
    def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
        """Open a DTLS client connection to ``(host, port)`` and wrap it in a stream.

        :param host: server host name or address.
        :param port: server UDP port.
        :param ssl_kwargs: not used here (presumably kept to mirror rpyc's
            ``ssl_connect`` signature).
        :param timeout: socket timeout in seconds, set before connecting.
        :param kwargs: only ``ipv6`` is consulted; when truthy an ``AF_INET6``
            socket is used, otherwise ``AF_INET``.
        :returns: a connected ``DTLSSocketStream`` whose handshake has completed.
        :raises: whatever connecting or handshaking raises; the socket is
            closed before the exception propagates.
        """
        family = socket.AF_INET6 if kwargs.pop('ipv6', False) else socket.AF_INET
        dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration(
            trust_store=trust_store,
            # NOTE(review): the server certificate is not validated, so the
            # channel is encrypted but the peer is not authenticated.
            validate_certificates=False,
        ))
        raw_sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            dtls_cli = dtls_cli_ctx.wrap_socket(raw_sock, server_hostname=None)
            dtls_cli.settimeout(timeout)
            dtls_cli.connect((host, port))
            block(dtls_cli.do_handshake)
        except BaseException:
            # Do not leak the UDP socket when wrapping, connecting or the
            # handshake fails.
            raw_sock.close()
            raise
        return cls(dtls_cli)

    def read(self, count):
        """Read ``count`` bytes via ``SocketStream.read``, retrying on TLS want-read/write."""
        # Hand ``block`` the callable and its arguments so that it performs (and
        # can repeat) the call itself, instead of receiving an already-computed
        # result that it would then try to call.
        # NOTE(review): if SocketStream.read accumulates data over several
        # recv() calls, a retry after a partial read restarts from scratch and
        # drops the bytes already received -- confirm this cannot happen here.
        return block(SocketStream.read, self, count)

    def write(self, data):
        """Write ``data`` via ``SocketStream.write``, retrying on TLS want-read/write."""
        # Same call shape as ``read``: ``block`` must invoke the method itself.
        block(SocketStream.write, self, data)
class DTLSThreadedServer(ThreadedServer):
    """rpyc ``ThreadedServer`` that accepts clients over DTLS (UDP) instead of TCP.

    :meth:`dtls` replaces the listener created by the base class with a
    python-mbedtls DTLS socket. :meth:`_listen` and
    :meth:`_authenticate_and_serve_client` are overridden for datagram
    semantics: there is no ``listen()`` call, and each client goes through a
    DTLS cookie exchange.
    NOTE(review): presumably :meth:`dtls` must be called after construction and
    before the server is started -- confirm with callers.
    """

    def dtls(self, listener_timeout=0.5, reuse_addr=True):
        """Replace the current listener with a DTLS socket bound to ``(self.host, self.port)``.

        A fresh end-entity certificate is issued by the module's intermediate
        CA. It is presented together with that intermediate certificate.

        :param listener_timeout: timeout in seconds set on the new listening socket.
        :param reuse_addr: set ``SO_REUSEADDR`` before binding (skipped on Windows).
        """
        # Read the family before closing the old listener so that an IPv6
        # server keeps an IPv6 listener; any other family falls back to IPv4.
        if self.listener.family == socket.AF_INET6:
            family = socket.AF_INET6
        else:
            family = socket.AF_INET
        self.listener.close()

        srv_crt, srv_key = dtls_certs.server_cert()
        dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration(
            trust_store=trust_store,
            certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key),
            # NOTE(review): client certificates are not validated.
            validate_certificates=False,
        ))
        dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(family, socket.SOCK_DGRAM))
        try:
            if reuse_addr and sys.platform != 'win32':
                dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            dtls_srv.bind((self.host, self.port))
            dtls_srv.settimeout(listener_timeout)
        except BaseException:
            dtls_srv.close()
            raise
        self.listener = dtls_srv

    def _listen(self):
        """Mark the server active without calling ``listen()``.

        A UDP socket cannot ``listen()``, so this override only does three
        things: it resolves the bound port when none was given, logs the
        address, and sets ``self.active``. It returns immediately if the server
        is already active.
        """
        if self.active:
            return
        if not self.port:
            # No explicit port: read back the one the OS assigned at bind time.
            self.port = self.listener.getsockname()[1]
        self.logger.info('server started on [%s]:%s', self.host, self.port)
        self.active = True

    def _authenticate_and_serve_client(self, sock):
        """Complete the DTLS handshake for a new peer, then hand it to rpyc.

        ``sock`` is the per-peer socket produced by the listener's ``accept()``.
        The first handshake attempt on it may end in ``HelloVerifyRequest``
        (DTLS cookie exchange), which is suppressed. ``sock.accept()`` then
        yields the socket used for the real handshake and for the rpyc session.

        The cookie-exchange socket is closed and removed from ``self.clients``
        once it is no longer needed. The session socket is registered in
        ``self.clients``, and it is closed if its own handshake fails.
        """
        hello_sock = sock
        try:
            hello_sock.setcookieparam(hello_sock.getpeername()[0].encode())
            with suppress(tls.HelloVerifyRequest):
                block(hello_sock.do_handshake)
            sock, addr = hello_sock.accept()
        finally:
            # rpyc's base implementation presumably only cleans up the socket it
            # is finally handed, so release the cookie-exchange socket here.
            # Otherwise it stays open and registered.
            self.clients.discard(hello_sock)
            hello_sock.close()

        # Track the session socket so that server shutdown can reach it.
        self.clients.add(sock)
        try:
            sock.setcookieparam(addr[0].encode())
            block(sock.do_handshake)
        except BaseException:
            self.clients.discard(sock)
            sock.close()
            raise
        ThreadedServer._authenticate_and_serve_client(self, sock)