From d67f48c05096aba3650663fb9b9b9052b1a23b5b Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Mon, 10 Dec 2012 20:52:49 -0800 Subject: [PATCH] Interactive performance test suite The new module test_perf.py can be used to characterize protocol performance over a particular network link. Two stream protocols (TCP and SSL) and two datagram protocols (UDP and DTLS) are available for relative comparison. The module will run servers in its process, and will spawn clients either into separate processes, or, depending on command line options, will expect one or more remote clients to connect to it. In the latter case, jobs will be sent to such clients via a shared queue whenever the user selects a test suite. Stress testing under packet loss conditions revealed that that the OpenSSL library's compression feature needed to be explicitly disabled for DTLS: it evidently operates at the stream layer as opposed to the datagram layer, and packet loss would result in corruption among the packets that were successfully received, authenticated, and decrypted. Several performance improvements are included in this patch. --- dtls/openssl.py | 12 +- dtls/patch.py | 2 + dtls/sslconnection.py | 12 +- dtls/test/test_perf.py | 412 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 435 insertions(+), 3 deletions(-) create mode 100644 dtls/test/test_perf.py diff --git a/dtls/openssl.py b/dtls/openssl.py index c48ea60..2f03edb 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -66,6 +66,7 @@ else: BIO_NOCLOSE = 0x00 BIO_CLOSE = 0x01 SSLEAY_VERSION = 0 +SSL_OP_NO_COMPRESSION = 0x00020000 SSL_VERIFY_NONE = 0x00 SSL_VERIFY_PEER = 0x01 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 @@ -89,6 +90,7 @@ CRYPTO_LOCK = 1 # SSL_CTRL_SET_SESS_CACHE_MODE = 44 SSL_CTRL_SET_READ_AHEAD = 41 +SSL_CTRL_OPTIONS = 32 BIO_CTRL_INFO = 3 BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 @@ -453,6 +455,7 @@ _subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSLEAY_VERSION", + "SSL_OP_NO_COMPRESSION", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", @@ -470,6 +473,7 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "BIO_dgram_get_peer", "BIO_dgram_set_peer", "BIO_set_nbio", "SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead", + "SSL_CTX_set_options", "SSL_read", "SSL_write", "SSL_CTX_set_cookie_cb", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", @@ -635,6 +639,10 @@ def SSL_CTX_set_read_ahead(ctx, m): # Returns the previous value of m _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_READ_AHEAD, m, None) +def SSL_CTX_set_options(ctx, options): + # Returns the new option bitmaks after adding the given options + _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) + _rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) _rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), c_uint) @@ -716,7 +724,9 @@ def SSL_read(ssl, length): return buf.raw[:res_len] def SSL_write(ssl, data): - if hasattr(data, "tobytes") and callable(data.tobytes): + if isinstance(data, str): + str_data = data + elif hasattr(data, "tobytes") and callable(data.tobytes): str_data = data.tobytes() else: str_data = str(data) diff --git a/dtls/patch.py b/dtls/patch.py index e922fe3..543f8d0 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -32,6 +32,8 @@ def do_patch(): global _orig_SSLSocket_init, _orig_get_server_certificate global ssl ssl = _ssl + if hasattr(ssl, "PROTOCOL_DTLSv1"): + return ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index c088302..e0b6a6c 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -175,6 +175,9 @@ class SSLConnection(object): def _config_ssl_ctx(self, verify_mode): SSL_CTX_set_verify(self._ctx.value, verify_mode) SSL_CTX_set_read_ahead(self._ctx.value, 1) + # Compression occurs at the stream layer now, leading to datagram + # corruption when packet loss occurs + SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION) if self._certfile: SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) if self._keyfile: @@ -237,10 +240,14 @@ class SSLConnection(object): def _check_nbio(self): timeout = self._sock.gettimeout() - BIO_set_nbio(self._wbio.value, timeout is not None) + if self._wbio_nb != timeout is not None: + BIO_set_nbio(self._wbio.value, timeout is not None) + self._wbio_nb = timeout is not None if self._wbio is not self._rbio: timeout = self._rsock.gettimeout() - BIO_set_nbio(self._rbio.value, timeout is not None) + if self._rbio_nb != timeout is not None: + BIO_set_nbio(self._rbio.value, timeout is not None) + self._rbio_nb = timeout is not None return timeout # read channel timeout def _wrap_socket_library_call(self, call, timeout_error): @@ -314,6 +321,7 @@ class SSLConnection(object): self._suppress_ragged_eofs = suppress_ragged_eofs self._ciphers = ciphers self._handshake_done = False + self._wbio_nb = self._rbio_nb = False if isinstance(sock, SSLConnection): post_init = self._copy_server() diff --git a/dtls/test/test_perf.py b/dtls/test/test_perf.py new file mode 100644 index 0000000..0479438 --- /dev/null +++ b/dtls/test/test_perf.py @@ -0,0 +1,412 @@ +# Performance tests for PyDTLS. Written by Ray Brown. +"""PyDTLS performance tests + +This module implements relative performance testing of throughput for the +PyDTLS package. Throughput for the following transports can be compared: + + * Python standard library stream transport (ssl module) + * PyDTLS datagram transport + * PyDTLS datagram transport with thread locking callbacks disabled + * PyDTLS datagram transport with demux type forced to routing demux +""" + +import socket +import errno +import ssl +import sys +import time +from argparse import ArgumentParser, ArgumentTypeError +from os import path, urandom +from timeit import timeit +from select import select +from multiprocessing import Process +from dtls import do_patch + +AF_INET4_6 = socket.AF_INET +CERTFILE = path.join(path.dirname(__file__), "certs", "keycert.pem") +CHUNK_SIZE = 1459 +CHUNKS = 150000 +CHUNKS_PER_DOT = 500 +COMM_KEY = "tronje%T577&kkjLp" + +# +# Traffic handler: required for servicing the root socket if the routing demux +# is used; only waits for traffic on the data socket with +# the osnet demux, as well as streaming sockets +# + +def handle_traffic(data_sock, listen_sock, err): + assert data_sock + assert err in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE) + readers = [] + writers = [] + if listen_sock: + readers.append(listen_sock) + if err == ssl.SSL_ERROR_WANT_READ: + readers.append(data_sock) + else: + writers.append(data_sock) + while True: + read_ready, write_ready, exc_ready = select(readers, writers, [], 5) + if not read_ready and not write_ready: + raise ssl.SSLError("timed out") + if data_sock in read_ready or data_sock in write_ready: + break + assert listen_sock in read_ready + acc_ret = listen_sock.accept() + assert acc_ret is None # test does not attempt multiple connections + +# +# Transfer functions: transfer data on non-blocking sockets; written to work +# properly for stream as well as message-based protocols +# + +fill = urandom(CHUNK_SIZE) + +def transfer_out(sock, listen_sock=None, marker=False): + max_i_len = 10 + start_char = "t" if marker else "s" + for i in xrange(CHUNKS): + prefix = start_char + str(i) + ":" + pad_prefix = prefix + "b" * (max_i_len - len(prefix)) + message = pad_prefix + fill[:CHUNK_SIZE - max_i_len - 1] + "e" + count = 0 + while count < CHUNK_SIZE: + try: + count += sock.send(message[count:]) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + handle_traffic(sock, listen_sock, err.args[0]) + else: + raise + except socket.error as err: + if err.errno == errno.EWOULDBLOCK: + handle_traffic(sock, None, ssl.SSL_ERROR_WANT_WRITE) + else: + raise + if not i % CHUNKS_PER_DOT: + sys.stdout.write('.') + sys.stdout.flush() + print + +def transfer_in(sock, listen_sock=None): + drops = 0 + pack_seq = -1 + i = 0 + try: + sock.getpeername() + except: + peer_set = False + else: + peer_set = True + while pack_seq + 1 < CHUNKS: + pack = "" + while len(pack) < CHUNK_SIZE: + try: + if isinstance(sock, ssl.SSLSocket): + segment = sock.recv(CHUNK_SIZE - len(pack)) + else: + segment, addr = sock.recvfrom(CHUNK_SIZE - len(pack)) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + try: + handle_traffic(sock, listen_sock, err.args[0]) + except ssl.SSLError as err: + if err.message == "timed out": + break + raise + else: + raise + except socket.error as err: + if err.errno == errno.EWOULDBLOCK: + try: + handle_traffic(sock, None, ssl.SSL_ERROR_WANT_READ) + except ssl.SSLError as err: + if err.message == "timed out": + break + raise + else: + raise + else: + pack += segment + if not peer_set: + sock.connect(addr) + peer_set = True + # Do not try to assembly packets from datagrams + if sock.type == socket.SOCK_DGRAM: + break + if len(pack) < CHUNK_SIZE or pack[0] == "t": + break + if pack[0] != "s" or pack[-1] != "e": + raise Exception("Corrupt message received") + next_seq = int(pack[1:pack.index(':')]) + if next_seq > pack_seq: + drops += next_seq - pack_seq - 1 + pack_seq = next_seq + if not i % CHUNKS_PER_DOT: + sys.stdout.write('.') + sys.stdout.flush() + i += 1 + drops += CHUNKS - 1 - pack_seq + print + return drops + +# +# Single-threaded server +# + +def server(sock_type, do_wrap, listen_addr): + sock = socket.socket(AF_INET4_6, sock_type) + sock.bind(listen_addr) + if do_wrap: + wrap = ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE, + do_handshake_on_connect=False, + ciphers="NULL") + wrap.listen(0) + else: + wrap = sock + if sock_type == socket.SOCK_STREAM: + wrap.listen(0) + yield wrap.getsockname() + if do_wrap or sock_type == socket.SOCK_STREAM: + conn = wrap.accept()[0] + else: + conn = wrap + wrap.setblocking(False) + conn.setblocking(False) + class InResult(object): pass + def _transfer_in(): + InResult.drops = transfer_in(conn, wrap) + in_time = timeit(_transfer_in, number=1) + yield in_time, InResult.drops + out_time = timeit(lambda: transfer_out(conn, wrap), number=1) + # Inform the client that we are done, in case it has missed the final chunk + if sock_type == socket.SOCK_DGRAM: + global CHUNKS, CHUNK_SIZE + CHUNKS_sav = CHUNKS + CHUNK_SIZE_sav = CHUNK_SIZE + try: + CHUNKS = 5 + CHUNK_SIZE = 10 + for _ in range(10): + try: + transfer_out(conn, wrap, True) + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_SYSCALL: + break + else: + raise + except socket.error as err: + if err.errno == errno.ECONNREFUSED: + break + else: + raise + time.sleep(0.2) + finally: + CHUNKS = CHUNKS_sav + CHUNK_SIZE = CHUNK_SIZE_sav + conn.shutdown(socket.SHUT_RDWR) + conn.close() + wrap.close() + yield out_time + +# +# Client, launched into a separate process +# + +def client(sock_type, do_wrap, listen_addr): + sock = socket.socket(AF_INET4_6, sock_type) + if do_wrap: + wrap = ssl.wrap_socket(sock, ciphers="NULL") + else: + wrap = sock + wrap.connect(listen_addr) + transfer_out(wrap) + drops = transfer_in(wrap) + wrap.shutdown(socket.SHUT_RDWR) + wrap.close() + return drops + +# +# Client manager - remote clients, run in a separate process +# + +def make_client_manager(): + # Create the global client manager class in servers configured as client + # managers + class ClientManager(object): + from Queue import Queue + + queue = Queue() + clients = -1 # creator does not count + + @classmethod + def get_queue(cls): + cls.clients += 1 + return cls.queue + + @classmethod + def release_clients(cls): + def wait_queue_empty(fail_return): + waitcount = 5 + while not cls.queue.empty() and waitcount: + time.sleep(1) + waitcount -= 1 + if not cls.queue.empty(): + # Clients are already dead or stuck + return fail_return + # Wait a moment for the queue to empty + wait_queue_empty("No live clients detected") + for _ in range(cls.clients): + cls.queue.put("STOP") + # Wait for all stop messages to be retrieved + wait_queue_empty("Not all clients responded to stop signal") + return "Client release succeeded" + globals()["ClientManager"] = ClientManager + +def get_queue(): + return ClientManager.get_queue() + +def release_clients(): + return ClientManager.release_clients() + +MANAGER = None +QUEUE = None + +def start_client_manager(port): + from multiprocessing.managers import BaseManager + global MANAGER, QUEUE + make_client_manager() + class Manager(BaseManager): pass + Manager.register("get_queue", get_queue) + Manager.register("release_clients", release_clients) + MANAGER = Manager(('', port), COMM_KEY) + MANAGER.start(make_client_manager) + QUEUE = MANAGER.get_queue() + +def stop_client_manager(): + global MANAGER, QUEUE + QUEUE = None + MANAGER.release_clients() + MANAGER.shutdown() + MANAGER = None + +def remote_client(manager_address): + from multiprocessing.managers import BaseManager + do_patch() + class Manager(BaseManager): pass + Manager.register("get_queue") + manager = Manager(manager_address, COMM_KEY) + manager.connect() + queue = manager.get_queue() + print "Client connected; waiting for job..." + while True: + command = queue.get() + if command == "STOP": + break + command = command[:-1] + [(manager_address[0], command[-1][1])] + print "Starting job: " + str(command) + drops = client(*command) + print "%d drops" % drops + print "Job completed; waiting for next job..." + +# +# Test runner +# + +def run_test(server_args, client_args): + if QUEUE: + # bind to all interfaces, for remote clients + listen_addr = '', 0 + else: + # bind to loopback only, for local clients + listen_addr = 'localhost', 0 + svr = iter(server(*server_args, listen_addr=listen_addr)) + listen_addr = svr.next() + listen_addr = 'localhost', listen_addr[1] + client_args = list(client_args) + client_args.append(listen_addr) + if QUEUE: + QUEUE.put(client_args) + else: + proc = Process(target=client, args=client_args) + proc.start() + in_size = CHUNK_SIZE * CHUNKS / 2**20 + out_size = CHUNK_SIZE * CHUNKS / 2**20 + print "Starting inbound: %dMiB" % in_size + svr_in_time, drops = svr.next() + print "Inbound: %.3f seconds, %dMiB/s, %d drops" % ( + svr_in_time, in_size / svr_in_time, drops) + print "Starting outbound: %dMiB" % out_size + svr_out_time = svr.next() + print "Outbound: %.3f seconds, %dMiB/s" % ( + svr_out_time, out_size / svr_out_time) + if not QUEUE: + proc.join() + print "Combined: %.3f seconds, %dMiB/s" % ( + svr_out_time + svr_in_time, + (in_size + out_size) / (svr_in_time + svr_out_time)) + +# +# Main entry point +# + +if __name__ == "__main__": + def port(string): + val = int(string) + if val < 1 or val > 2**16: + raise ArgumentTypeError("%d is an invalid port number" % val) + return val + def endpoint(string): + addr = string.split(':') + if len(addr) != 2: + raise ArgumentTypeError("%s is not a valid host endpoint" % string) + addr[1] = port(addr[1]) + socket.getaddrinfo(addr[0], addr[1], socket.AF_INET) + return tuple(addr) + parser = ArgumentParser() + parser.add_argument("-s", "--server", type=port, metavar="PORT", + help="local server port for remote clients") + parser.add_argument("-c", "--client", type=endpoint, metavar="ENDPOINT", + help="remote server endpoint for this client") + args = parser.parse_args() + if args.client: + remote_client(args.client) + sys.exit() + if args.server: + start_client_manager(args.server) + suites = { + "Raw TCP": (socket.SOCK_STREAM, False), + "Raw UDP": (socket.SOCK_DGRAM, False), + "SSL (TCP)": (socket.SOCK_STREAM, True), + "DTLS (UDP)": (socket.SOCK_DGRAM, True), + } + selector = { + 0: "Exit", + 1: "Raw TCP", + 2: "Raw UDP", + 3: "SSL (TCP)", + 4: "DTLS (UDP)", + } + do_patch() + while True: + print "\nSelect protocol:\n" + for key in sorted(selector): + print "\t" + str(key) + ": " + selector[key] + try: + choice = raw_input("\nProtocol: ") + choice = int(choice) + if choice < 0 or choice >= len(selector): + raise ValueError("Invalid selection input") + except (ValueError, OverflowError): + print "Invalid selection input" + continue + except EOFError: + break + if not choice: + break + run_test(suites[selector[choice]], suites[selector[choice]]) + if args.server: + stop_client_manager()