From 9bee9d049fda5a1fc22090f5d10fe1e9f6af1a97 Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Sun, 25 Nov 2012 13:44:55 -0800 Subject: [PATCH] Echo server connection timeout and monitoring Prior to this change, the unit tests' echo servers' connections would sometimes linger beyond server termination. A timeout mechanism is now implemented, which will terminate a connection and clean up its resources when the timeout is reached. For the purpose of unit testing, test echo servers now assert that all connections have been terminated when the servers are closed. For threaded echo servers, the timeout mechanism involves the use of sockets with timeouts instead of blocking sockets. This required an implementation with the proper handling of timeout sockets at the sslconnection level. --- dtls/err.py | 8 +++-- dtls/sslconnection.py | 58 +++++++++++++++++++++++--------- dtls/test/echo_seq.py | 9 ++--- dtls/test/unit.py | 78 +++++++++++++++++++++++++++++++++++-------- 4 files changed, 119 insertions(+), 34 deletions(-) diff --git a/dtls/err.py b/dtls/err.py index 9dbe563..6c5fc40 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -26,8 +26,10 @@ ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 ERR_NO_CIPHER = 501 -ERR_HANDSHAKE_TIMEOUT = 502 -ERR_PORT_UNREACHABLE = 503 +ERR_READ_TIMEOUT = 502 +ERR_WRITE_TIMEOUT = 503 +ERR_HANDSHAKE_TIMEOUT = 504 +ERR_PORT_UNREACHABLE = 505 ERR_COOKIE_MISMATCH = 0x1408A134 @@ -88,6 +90,8 @@ _ssl_errors = { ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \ "specified for server-side operation", ERR_NO_CIPHER: "No cipher can be selected.", + ERR_READ_TIMEOUT: "The read operation timed out", + ERR_WRITE_TIMEOUT: "The write operation timed out", ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out", ERR_PORT_UNREACHABLE: "The peer address is not reachable", } diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 3976a62..d2323b5 100644 --- a/dtls/sslconnection.py +++ 
b/dtls/sslconnection.py @@ -26,14 +26,17 @@ library's ssl module, since its values can be passed to this module. import errno import socket import hmac +import datetime from logging import getLogger from os import urandom +from select import select from weakref import proxy from err import openssl_error, InvalidSocketError from err import raise_ssl_error from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE +from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from x509 import _X509, decode_cert from tlock import tlock_init from openssl import * @@ -214,9 +217,35 @@ class SSLConnection(object): return lambda: self.do_handshake() def _check_nbio(self): - BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) + timeout = self._sock.gettimeout() + BIO_set_nbio(self._wbio.value, timeout is not None) if self._wbio is not self._rbio: - BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None) + timeout = self._rsock.gettimeout() + BIO_set_nbio(self._rbio.value, timeout is not None) + return timeout # read channel timeout + + def _wrap_socket_library_call(self, call, timeout_error): + timeout_sec_start = timeout_sec = self._check_nbio() + # Pass the call if the socket is blocking or non-blocking + if not timeout_sec: # None (blocking) or zero (non-blocking) + return call() + start_time = datetime.datetime.now() + read_sock = self.get_socket(True) + need_select = False + while timeout_sec > 0: + if need_select: + if not select([read_sock], [], [], timeout_sec)[0]: + break + timeout_sec = timeout_sec_start - \ + (datetime.datetime.now() - start_time).total_seconds() + try: + return call() + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_WANT_READ: + need_select = True + continue + raise + raise_ssl_error(timeout_error) def _get_cookie(self, ssl): assert self._listening @@ -426,14 +455,12 @@ class 
SSLConnection(object): """ _logger.debug("Initiating handshake...") - self._check_nbio() try: - SSL_do_handshake(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_do_handshake(self._ssl.value), + ERR_HANDSHAKE_TIMEOUT) except openssl_error() as err: - if err.ssl_error == SSL_ERROR_WANT_READ and \ - self.get_socket(True).gettimeout(): - raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err) - elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: raise_ssl_error(ERR_PORT_UNREACHABLE, err) raise self._handshake_done = True @@ -450,8 +477,8 @@ class SSLConnection(object): string containing read bytes """ - self._check_nbio() - return SSL_read(self._ssl.value, len) + return self._wrap_socket_library_call( + lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT) def write(self, data): """Write data to connection @@ -465,8 +492,8 @@ class SSLConnection(object): number of bytes actually transmitted """ - self._check_nbio() - return SSL_write(self._ssl.value, data) + return self._wrap_socket_library_call( + lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) def shutdown(self): """Shut down the DTLS connection @@ -480,9 +507,9 @@ class SSLConnection(object): # Listening server-side sockets cannot be shut down return - self._check_nbio() try: - SSL_shutdown(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) except openssl_error() as err: if err.result == 0: # close-notify alert was just sent; wait for same from peer @@ -490,7 +517,8 @@ class SSLConnection(object): # with SSL_set_read_ahead here, doing so causes a shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown # initiator side. And test_starttls does pass. 
- SSL_shutdown(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) else: raise if hasattr(self, "_udp_demux"): diff --git a/dtls/test/echo_seq.py b/dtls/test/echo_seq.py index 7a3bf3b..1bd91f9 100644 --- a/dtls/test/echo_seq.py +++ b/dtls/test/echo_seq.py @@ -59,7 +59,7 @@ def main(): try: conn.do_handshake() except SSLError as err: - if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("504:"): continue raise print "Completed handshaking with peer" @@ -75,7 +75,7 @@ def main(): try: message = conn.read() except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("502:"): continue if err.args[0] == SSL_ERROR_ZERO_RETURN: break @@ -91,9 +91,10 @@ def main(): assert not peer_address print "Shutdown invocation: %d" % cnt try: - conn.shutdown() + s = conn.shutdown() + s.shutdown(socket.SHUT_RDWR) except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("502:"): continue raise break diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 17f6cd4..b585889 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -15,14 +15,17 @@ import traceback import weakref import platform import threading +import time import datetime import SocketServer from SimpleHTTPServer import SimpleHTTPRequestHandler +from collections import OrderedDict import ssl from dtls import do_patch HOST = "localhost" +CONNECTION_TIMEOUT = datetime.timedelta(seconds=30) class TestSupport(object): verbose = True @@ -37,7 +40,6 @@ class TestSupport(object): def __exit__(self, exc_type, exc_value, traceback): self.server.stop() - self.server.join() self.server = None def transient_internet(self): @@ -134,7 +136,6 @@ class BasicSocketTests(unittest.TestCase): s.connect(remote) finally: server.stop() - server.join() @unittest.skipIf(platform.python_implementation() != "CPython", "Reference cycle test feasible under CPython only") @@ -331,9 +332,10 
@@ class ThreadedEchoServer(threading.Thread): self.server = server self.running = False self.sock = connsock - self.sock.setblocking(True) + self.sock.settimeout(CONNECTION_TIMEOUT.total_seconds()) self.sslconn = connsock threading.Thread.__init__(self) + server.register_handler(True) self.daemon = True def show_conn_details(self): @@ -389,6 +391,7 @@ class ThreadedEchoServer(threading.Thread): return self.sock.send(bytes) def close(self): + self.server.register_handler(False) if self.sslconn: self.sslconn.close() else: @@ -486,6 +489,8 @@ class ThreadedEchoServer(threading.Thread): self.starttls_server = starttls_server self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) self.flag = None + self.num_handlers = 0 + self.num_handlers_lock = threading.Lock() self.sock = ssl.wrap_socket(self.sock, server_side=True, certfile=self.certificate, cert_reqs=self.certreqs, @@ -504,6 +509,7 @@ class ThreadedEchoServer(threading.Thread): def start(self, flag=None): self.flag = flag + self.starter = threading.current_thread().ident threading.Thread.start(self) def run(self): @@ -529,8 +535,26 @@ class ThreadedEchoServer(threading.Thread): self.stop() self.sock.close() + def register_handler(self, add): + with self.num_handlers_lock: + if add: + self.num_handlers += 1 + else: + self.num_handlers -= 1 + assert self.num_handlers >= 0 + def stop(self): self.active = False + if self.starter != threading.current_thread().ident: + return + self.join() # don't allow spawning new handlers after we've checked + last_msg = datetime.datetime.now() + while self.num_handlers: + time.sleep(0.05) + now = datetime.datetime.now() + if now > last_msg + datetime.timedelta(seconds=1): + sys.stdout.write(' server: waiting for connections to close\n') + last_msg = now class AsyncoreEchoServer(threading.Thread): @@ -538,9 +562,10 @@ class AsyncoreEchoServer(threading.Thread): class ConnectionHandler(asyncore.dispatcher): - def __init__(self, conn, timeout_tracker): + def __init__(self, conn, 
timeout_tracker, server): asyncore.dispatcher.__init__(self, conn) self._timeout_tracker = timeout_tracker + self._server = server self._ssl_accepting = True # Complete the handshake self.handle_read_event() @@ -585,6 +610,11 @@ class AsyncoreEchoServer(threading.Thread): data = self.recv(1024) if data and data.strip() != 'over': self.send(data.lower()) + if self.connected: + self._server.reset_timeout(self) + self._server.check_timeout() + if not self.connected: # above called handle_close + return delta = self.socket.get_timeout() if delta: self._timeout_tracker[self] = \ @@ -593,6 +623,7 @@ class AsyncoreEchoServer(threading.Thread): def handle_close(self): if self._timeout_tracker.has_key(self): self._timeout_tracker.pop(self) + self._server._handlers.pop(self) self.close() if test_support.verbose: sys.stdout.write(" server: closed connection %s\n" % @@ -604,6 +635,7 @@ class AsyncoreEchoServer(threading.Thread): def __init__(self, certfile, timeout_tracker): asyncore.dispatcher.__init__(self) self._timeout_tracker = timeout_tracker + self._handlers = OrderedDict() sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) sock.setblocking(False) sock.bind((HOST, 0)) @@ -618,17 +650,43 @@ class AsyncoreEchoServer(threading.Thread): return False def handle_accept(self): + self.check_timeout() acc_ret = self.accept() if acc_ret: sock_obj, addr = acc_ret if test_support.verbose: sys.stdout.write(" server: new connection from " + "%s:%s\n" % (addr[0], str(addr[1:]))) - self.ConnectionHandler(sock_obj, self._timeout_tracker) + self._handlers[self.ConnectionHandler(sock_obj, + self._timeout_tracker, + self)] = \ + datetime.datetime.now() def handle_error(self): raise + def reset_timeout(self, handler): + if self._handlers.has_key(handler): + self._handlers.pop(handler) + self._handlers[handler] = datetime.datetime.now() + + def check_timeout(self): + now = datetime.datetime.now() + while True: + try: + handler = self._handlers.__iter__().next() # oldest handler + except 
StopIteration: + break # there are no more handlers + if now > self._handlers[handler] + CONNECTION_TIMEOUT: + handler.handle_close() + else: + break # the oldest handler has not yet timed out + + def close(self): + map(lambda x: x.handle_close(), self._handlers.keys()) + assert not self._handlers + asyncore.dispatcher.close(self) + def __init__(self, certfile): self.flag = None self.active = False @@ -660,6 +718,7 @@ class AsyncoreEchoServer(threading.Thread): def stop(self): self.active = False + self.join() self.server.close() # Note that this HTTP-over-UDP server does not implement packet recovery and @@ -800,7 +859,6 @@ def bad_cert_test(certfile): raise AssertionError("Use of invalid cert should have failed!") finally: server.stop() - server.join() def server_params_test(certfile, protocol, certreqs, cacertsfile, client_certfile, client_protocol=None, @@ -854,7 +912,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile, s.close() finally: server.stop() - server.join() def try_protocol_combo(server_protocol, client_protocol, @@ -950,10 +1007,10 @@ class ThreadedTests(unittest.TestCase): self.fail( "Missing or invalid 'organizationName' field in " "certificate subject; should be 'Ray Srv Inc'.") + s.write("over\n") s.close() finally: server.stop() - server.join() def test_empty_cert(self): """Connecting with an empty cert file""" @@ -1045,7 +1102,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - server.join() def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" @@ -1095,7 +1151,6 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(d1, ''.join(d2)) finally: server.stop() - server.join() def test_asyncore_server(self): """Check the example asyncore integration.""" @@ -1130,8 +1185,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - # wait for server thread to end - server.join() def test_recv_send(self): """Test recv(), send() and
friends.""" @@ -1246,7 +1299,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - server.join() def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout