diff --git a/mods/rpyc_dtls.py b/mods/rpyc_dtls.py index 7151af5..9dbf606 100644 --- a/mods/rpyc_dtls.py +++ b/mods/rpyc_dtls.py @@ -66,7 +66,7 @@ cli_ctx_conf = tls.DTLSConfiguration( trust_store=trust_store, validate_certificates=False, ) -MAX_IO_CHUNK = 20971520 +MAX_IO_CHUNK = 8192 class DTLSSocketStream(SocketStream): MAX_IO_CHUNK = MAX_IO_CHUNK @@ -90,7 +90,7 @@ class DTLSSocketStream(SocketStream): def read(self, count): while True: try: - buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count)) + buf = block(self.sock.recv, self.MAX_IO_CHUNK) except socket.timeout: continue except socket.error: @@ -118,14 +118,30 @@ class DTLSSocketStream(SocketStream): class DTLSChannel(Channel): MAX_IO_CHUNK = MAX_IO_CHUNK def recv(self): - raw_data = self.stream.read(self.MAX_IO_CHUNK) - header = raw_data[:self.FRAME_HEADER.size] - raw_data = raw_data[self.FRAME_HEADER.size:] + header = self.stream.read(self.MAX_IO_CHUNK) length, compressed = self.FRAME_HEADER.unpack(header) - data = raw_data[:length] + length += len(self.FLUSHER) + data = b'' + while length: + dat = self.stream.read(self.MAX_IO_CHUNK) + data += dat + length -= len(dat) + data = data[:-len(self.FLUSHER)] if compressed: data = zlib.decompress(data) return data + def send(self, data): + if self.compress and len(data) > self.COMPRESSION_THRESHOLD: + compressed = 1 + data = zlib.compress(data, self.COMPRESSION_LEVEL) + else: + compressed = 0 + header = self.FRAME_HEADER.pack(len(data), compressed) + self.stream.write(header) + data = data + self.FLUSHER + data = [data[i:i + self.MAX_IO_CHUNK] for i in range(0, len(data), self.MAX_IO_CHUNK)] + for chunk in data: + self.stream.write(chunk) def connect_stream(stream, service=rpyc.VoidService, config={}): return connect_channel(DTLSChannel(stream), service=service, config=config) @@ -205,18 +221,17 @@ class DTLSThreadedServer(ThreadedServer): return sock.setblocking(True) - self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno()) - print("accepted %s with fd %s" % (addrinfo, sock.fileno())) - self.clients.add(sock) - self._accept_method(sock) - def _accept_method(self, sock): addr = sock.getpeername() sock.setcookieparam(addr[0].encode()) with suppress(tls.HelloVerifyRequest): block(sock.do_handshake) - sock, addr = sock.accept() - sock.setcookieparam(addr[0].encode()) - block(sock.do_handshake) - spawn(self._authenticate_and_serve_client, sock) - + sock2, addr = sock.accept() + sock.close() + sock2.setblocking(True) + sock2.setcookieparam(addr[0].encode()) + block(sock2.do_handshake) + self.logger.info("accepted %s with fd %s", addrinfo, sock2.fileno()) + print("accepted %s with fd %s" % (addrinfo, sock2.fileno())) + self.clients.add(sock2) + self._accept_method(sock2)