223 lines
8.5 KiB
Python
223 lines
8.5 KiB
Python
# Based on rpyc (4.0.2)
|
|
|
|
# WARNING: multiple client connections do not work; treat this DTLS server as single-client.
|
|
|
|
import rpyc
|
|
from rpyc.utils.server import ThreadedServer, spawn
|
|
from rpyc.core import SocketStream, Channel
|
|
from rpyc.core.stream import retry_errnos
|
|
from rpyc.utils.factory import connect_channel
|
|
import socket
|
|
from mbedtls import tls
|
|
import datetime as dt
|
|
from mbedtls import hash as hashlib
|
|
from mbedtls import pk
|
|
from mbedtls import x509
|
|
from uuid import uuid4
|
|
from contextlib import suppress
|
|
import sys
|
|
from rpyc.lib import safe_import
|
|
zlib = safe_import("zlib")
|
|
import errno
|
|
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
|
|
|
def block(callback, *args, **kwargs):
    """Call ``callback(*args, **kwargs)`` until it completes and return its result.

    Whenever the call raises ``tls.WantReadError`` or ``tls.WantWriteError`` it
    is simply attempted again (busy-wait: there is no sleep between attempts).
    Any other exception propagates to the caller unchanged.
    """
    retryable = (tls.WantReadError, tls.WantWriteError)
    while True:
        try:
            return callback(*args, **kwargs)
        except retryable:
            continue
|
|
|
|
class DTLSCerts:
    """Generate an ephemeral certificate chain for the DTLS endpoints.

    Built at construction time:

    * ``ca0_key`` / ``ca0_crt`` -- self-signed RSA root CA
      (``BasicConstraints(True, 1)``);
    * ``ca1_key`` / ``ca1_crt`` -- ECC intermediate CA signed by the root
      (``BasicConstraints(ca=True, max_path_length=3)``);
    * ``srv_key`` / ``srv_crt`` -- ECC end-entity certificate signed by the
      intermediate, produced by :meth:`server_cert`.

    Every certificate is valid from "now" (naive UTC datetime from
    ``utcnow()``) for ``VALIDITY``. New keys are generated on every
    instantiation.
    """

    # Validity period applied to every certificate issued by this class.
    VALIDITY = dt.timedelta(days=3650)

    def __init__(self):
        now = dt.datetime.utcnow()
        not_after = now + self.VALIDITY

        # Root CA: RSA key, self-signed with serial 0x123456.
        self.ca0_key = pk.RSA()
        self.ca0_key.generate()
        ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
        self.ca0_crt = x509.CRT.selfsign(
            ca0_csr, self.ca0_key,
            not_before=now, not_after=not_after,
            serial_number=0x123456,
            basic_constraints=x509.BasicConstraints(True, 1))

        # Intermediate CA: ECC key, signed by the root. The root is the issuer
        # of both its own certificate and this one, so this serial must differ
        # from 0x123456 (RFC 5280 requires serials to be unique per issuer).
        self.ca1_key = pk.ECC()
        self.ca1_key.generate()
        ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
        self.ca1_crt = self.ca0_crt.sign(
            ca1_csr, self.ca0_key, now, not_after, 0x123457,
            basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))

        self.srv_crt, self.srv_key = self.server_cert()

    def server_cert(self):
        """Issue a new ECC end-entity certificate signed by the intermediate CA.

        The subject CN embeds a random UUID so each peer gets a distinct name.

        Returns:
            ``(certificate, private_key)``.
        """
        now = dt.datetime.utcnow()
        ee0_key = pk.ECC()
        ee0_key.generate()
        ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
        # NOTE(review): the serial is fixed, so calling this method more than
        # once issues certificates with duplicate serials under the same issuer.
        ee0_crt = self.ca1_crt.sign(
            ee0_csr, self.ca1_key, now, now + self.VALIDITY, 0x987654)
        return ee0_crt, ee0_key
|
|
|
|
# Ephemeral PKI created at import time: every process importing this module
# generates its own root CA, intermediate CA and server certificate.
dtls_certs = DTLSCerts()

# Only the (process-local) root CA is placed in the trust store.
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)

# Server side: present the server certificate followed by the intermediate CA.
# NOTE(review): validate_certificates=False on both sides presumably disables
# peer certificate verification, leaving the session unauthenticated. Because
# each process generates its own CA, a peer in another process could not
# validate this chain anyway -- TODO: share a persistent CA before enabling it.
srv_ctx_conf = tls.DTLSConfiguration(
    certificate_chain=(
        [dtls_certs.srv_crt, dtls_certs.ca1_crt],
        dtls_certs.srv_key,
    ),
    trust_store=trust_store,
    validate_certificates=False,
)

# Client side: no client certificate, same trust store, validation disabled.
cli_ctx_conf = tls.DTLSConfiguration(
    validate_certificates=False,
    trust_store=trust_store,
)

# Upper bound (20 MiB) on a single recv()/send() size used by the DTLS stream
# and channel classes.
MAX_IO_CHUNK = 20 * 1024 * 1024
|
|
|
|
class DTLSSocketStream(SocketStream):
    """rpyc stream over a python-mbedtls DTLS-wrapped UDP socket.

    Unlike rpyc's TCP ``SocketStream``, ``read`` performs a single ``recv``
    and ``write`` a single ``send``; ``DTLSChannel.recv`` relies on one read
    returning one complete rpyc frame.
    """

    MAX_IO_CHUNK = MAX_IO_CHUNK

    @classmethod
    def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
        """Open a DTLS client connection to ``(host, port)`` and wrap it.

        Args:
            host, port: server address.
            ssl_kwargs: accepted for interface compatibility; currently unused
                (the module-level ``cli_ctx_conf`` is always used).
            timeout: socket timeout in seconds, set before connect/handshake.
                ``read`` later retries on timeouts, so it mainly bounds the
                handshake.
            **kwargs: ``ipv6`` (bool) selects AF_INET6; anything else (e.g.
                ``keepalive``) is ignored.

        Returns:
            A ``DTLSSocketStream`` around the handshaken socket.

        The UDP socket is closed if connecting or the handshake fails.
        """
        family = socket.AF_INET6 if kwargs.pop('ipv6', False) else socket.AF_INET
        # For mbedtls debug output:
        #   tls._enable_debug_output(cli_ctx_conf); tls._set_debug_level(10)
        dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
        raw_sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            dtls_cli = dtls_cli_ctx.wrap_socket(raw_sock, server_hostname=None)
            dtls_cli.settimeout(timeout)
            dtls_cli.connect((host, port))
            block(dtls_cli.do_handshake)
        except BaseException:
            # Don't leak the file descriptor on a failed connect/handshake.
            raw_sock.close()
            raise
        return cls(dtls_cli)

    def read(self, count):
        """Return the data from one ``recv`` of at most ``min(count, MAX_IO_CHUNK)`` bytes.

        Timeouts and errnos listed in rpyc's ``retry_errnos`` are retried.

        Raises:
            EOFError: on any other socket error, or when ``recv`` returns no
                data (peer closed). The stream is closed first.
        """
        while True:
            try:
                buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
            except socket.timeout:
                continue
            except socket.error as ex:
                if get_exc_errno(ex) in retry_errnos:
                    # Transient error that rpyc treats as retryable (some of
                    # these are Windows-specific).
                    continue
                self.close()
                raise EOFError(ex) from ex
            break
        if not buf:
            self.close()
            raise EOFError("connection closed by peer")
        return buf

    def write(self, data):
        """Send ``data`` with a single ``send`` call.

        NOTE(review): bytes beyond ``MAX_IO_CHUNK`` are silently dropped and the
        return value of ``send`` is ignored; oversized frames are presumably
        rejected by the DTLS layer first -- TODO confirm.

        Raises:
            EOFError: on socket error (the stream is closed first).
        """
        try:
            block(self.sock.send, data[:self.MAX_IO_CHUNK])
        except socket.error as ex:
            self.close()
            raise EOFError(ex) from ex
|
|
|
|
class DTLSChannel(Channel):
    """rpyc Channel that expects each stream read to carry one whole frame."""

    MAX_IO_CHUNK = MAX_IO_CHUNK

    def recv(self):
        """Read one datagram and decode the rpyc frame it contains.

        Frame layout: ``FRAME_HEADER`` (payload length, compressed flag),
        then the payload. Any bytes past ``length`` are ignored (rpyc's
        sender appends a trailer after the payload).

        Returns:
            The payload, zlib-decompressed when the compressed flag is set.

        Raises:
            EOFError: if the datagram is shorter than the header or than the
                length the header announces (truncated frame).
        """
        raw_data = self.stream.read(self.MAX_IO_CHUNK)
        header_size = self.FRAME_HEADER.size
        if len(raw_data) < header_size:
            raise EOFError("truncated frame header: got %d of %d bytes"
                           % (len(raw_data), header_size))
        length, compressed = self.FRAME_HEADER.unpack_from(raw_data)
        data = raw_data[header_size:header_size + length]
        if len(data) < length:
            raise EOFError("truncated frame: expected %d payload bytes, got %d"
                           % (length, len(data)))
        if compressed:
            data = zlib.decompress(data)
        return data
|
|
|
|
def connect_stream(stream, service=rpyc.VoidService, config=None):
    """Create an rpyc connection over ``stream`` using DTLS framing.

    Args:
        stream: a connected ``DTLSSocketStream``.
        service: local rpyc service exposed to the peer.
        config: rpyc protocol config; ``None`` means an empty config (a fresh
            dict per call instead of a shared mutable default).

    Returns:
        The connection returned by rpyc's ``connect_channel``.
    """
    if config is None:
        config = {}
    return connect_channel(DTLSChannel(stream), service=service, config=config)
|
|
|
|
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
                 cert_reqs=None, ssl_version=None, ciphers=None,
                 service=rpyc.VoidService, config=None, ipv6=False, keepalive=False):
    """Connect to a DTLS rpyc server and return the rpyc connection.

    The signature mirrors rpyc's ``ssl_connect``. ``keyfile``, ``certfile``,
    ``ca_certs``, ``cert_reqs`` and ``ssl_version`` are accepted but ignored.
    ``ciphers`` is recorded in ``ssl_kwargs``, which
    ``DTLSSocketStream.dtls_connect`` does not use. ``keepalive`` is
    forwarded but has no effect on a UDP socket.

    Args:
        host, port: server address.
        service: local rpyc service exposed to the server.
        config: rpyc protocol config; ``None`` means an empty config.
        ipv6: use AF_INET6 instead of AF_INET.
    """
    ssl_kwargs = {'server_side': False}
    if ciphers is not None:
        ssl_kwargs['ciphers'] = ciphers
    stream = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
    # Avoid a shared mutable default: build a fresh config dict per call.
    return connect_stream(stream, service, config if config is not None else {})
|
|
|
|
class DTLSThreadedServer(ThreadedServer):
    """rpyc ThreadedServer that serves clients over DTLS (UDP) instead of TCP.

    ``ThreadedServer.__init__`` creates a TCP listener; it is closed and
    replaced by a python-mbedtls DTLS-wrapped UDP socket bound to
    ``(hostname, port)``. Each accepted client goes through the DTLS cookie
    exchange and handshake in the accept thread, then is served in its own
    thread via ``spawn``.

    NOTE(review): the module header states multiple connections do not work;
    treat this server as single-client.
    """

    def __init__(self, service, hostname="", ipv6=False, port=0,
                 backlog=10, reuse_addr=True, authenticator=None, registrar=None,
                 auto_register=None, protocol_config=None, logger=None, listener_timeout=0.5,
                 socket_path=None):
        # Fresh dict per instance instead of a shared mutable default.
        if protocol_config is None:
            protocol_config = {}
        ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
                                backlog=backlog, reuse_addr=reuse_addr,
                                authenticator=authenticator, registrar=registrar,
                                auto_register=auto_register, protocol_config=protocol_config,
                                logger=logger, listener_timeout=listener_timeout,
                                socket_path=socket_path)
        # The base class bound a stream (TCP) listener; DTLS needs a UDP one.
        self.listener.close()

        # For mbedtls debug output:
        #   tls._enable_debug_output(srv_ctx_conf); tls._set_debug_level(10)
        dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
        family = socket.AF_INET6 if ipv6 else socket.AF_INET
        raw_sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            dtls_srv = dtls_srv_ctx.wrap_socket(raw_sock)
            if reuse_addr and sys.platform != 'win32':
                dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            dtls_srv.bind((hostname, port))
        except BaseException:
            raw_sock.close()
            raise
        # NOTE(review): listener_timeout is accepted but not applied to the
        # DTLS socket (settimeout was disabled in the original) -- TODO confirm why.
        self.listener = dtls_srv
        # With port=0 the OS picks the port; report what was actually bound.
        sockname = self.listener.getsockname()
        self.host, self.port = sockname[0], sockname[1]

    def _serve_client(self, sock, credentials):
        """Run an rpyc connection over a handshaken DTLS socket until it ends."""
        addrinfo = sock.getpeername()
        if credentials:
            self.logger.info("welcome %s (%r)", addrinfo, credentials)
        else:
            self.logger.info("welcome %s", addrinfo)
        try:
            config = dict(self.protocol_config, credentials=credentials,
                          endpoints=(sock.getsockname(), addrinfo), logger=self.logger)
            conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
            self._handle_connection(conn)
        finally:
            self.logger.info("goodbye %s", addrinfo)

    def _listen(self):
        """Mark the server active.

        UDP sockets have no ``listen()`` step, so nothing is called on the
        listener here.
        """
        if self.active:
            return
        if not self.port:
            self.port = self.listener.getsockname()[1]
        self.logger.info('server started on [%s]:%s', self.host, self.port)
        self.active = True

    def accept(self):
        """Wait for the next DTLS client and pass it to ``_accept_method``.

        While the server is active, socket timeouts and EINTR/EAGAIN are
        retried; any other socket error is raised as ``EOFError``.
        """
        while self.active:
            self.logger.debug('Accepting new connections!')
            try:
                sock, addrinfo = self.listener.accept()
            except socket.timeout:
                continue
            except socket.error as ex:
                if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
                    continue
                raise EOFError() from ex
            break

        if not self.active:
            return

        sock.setblocking(True)
        self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
        self.clients.add(sock)
        self._accept_method(sock)

    def _accept_method(self, sock):
        """Complete the DTLS handshake for ``sock`` and serve it in a new thread.

        The first handshake attempt ends in ``tls.HelloVerifyRequest`` (cookie
        exchange, cookie bound to the peer IP). ``sock.accept()`` then yields a
        new connection object on which the real handshake is completed.

        Any failure is logged and the client dropped, so that a single bad
        client does not propagate an exception out of ``accept()``.

        NOTE(review): the handshake runs in the accept thread, so a client that
        stalls mid-handshake blocks accepting other clients.
        """
        conn = None
        try:
            peer = sock.getpeername()
            sock.setcookieparam(peer[0].encode())
            with suppress(tls.HelloVerifyRequest):
                block(sock.do_handshake)
            conn, peer = sock.accept()
            conn.setcookieparam(peer[0].encode())
            block(conn.do_handshake)
        except Exception:
            self.logger.exception("DTLS handshake failed; dropping client")
            self.clients.discard(sock)
            for stale in (conn, sock):
                if stale is not None:
                    with suppress(Exception):
                        stale.close()
            return
        # accept() registered the pre-handshake socket; track the socket that is
        # actually served instead (presumably rpyc's
        # _authenticate_and_serve_client discards it from self.clients when the
        # session ends -- verify).
        self.clients.discard(sock)
        self.clients.add(conn)
        spawn(self._authenticate_and_serve_client, conn)
|
|
|
|
|