diff --git a/dtls/err.py b/dtls/err.py index 9dbe563..6c5fc40 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -26,8 +26,10 @@ ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 ERR_NO_CIPHER = 501 -ERR_HANDSHAKE_TIMEOUT = 502 -ERR_PORT_UNREACHABLE = 503 +ERR_READ_TIMEOUT = 502 +ERR_WRITE_TIMEOUT = 503 +ERR_HANDSHAKE_TIMEOUT = 504 +ERR_PORT_UNREACHABLE = 505 ERR_COOKIE_MISMATCH = 0x1408A134 @@ -88,6 +90,8 @@ _ssl_errors = { ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \ "specified for server-side operation", ERR_NO_CIPHER: "No cipher can be selected.", + ERR_READ_TIMEOUT: "The read operation timed out", + ERR_WRITE_TIMEOUT: "The write operation timed out", ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out", ERR_PORT_UNREACHABLE: "The peer address is not reachable", } diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 3976a62..d2323b5 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -26,14 +26,17 @@ library's ssl module, since its values can be passed to this module. import errno import socket import hmac +import datetime from logging import getLogger from os import urandom +from select import select from weakref import proxy from err import openssl_error, InvalidSocketError from err import raise_ssl_error from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE +from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from x509 import _X509, decode_cert from tlock import tlock_init from openssl import * @@ -214,9 +217,35 @@ class SSLConnection(object): return lambda: self.do_handshake() def _check_nbio(self): - BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) + timeout = self._sock.gettimeout() + BIO_set_nbio(self._wbio.value, timeout is not None) if self._wbio is not self._rbio: - BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None) + timeout = self._rsock.gettimeout() + BIO_set_nbio(self._rbio.value, timeout is not None) + return timeout # read channel timeout + + def _wrap_socket_library_call(self, call, timeout_error): + timeout_sec_start = timeout_sec = self._check_nbio() + # Pass the call if the socket is blocking or non-blocking + if not timeout_sec: # None (blocking) or zero (non-blocking) + return call() + start_time = datetime.datetime.now() + read_sock = self.get_socket(True) + need_select = False + while timeout_sec > 0: + if need_select: + if not select([read_sock], [], [], timeout_sec)[0]: + break + timeout_sec = timeout_sec_start - \ + (datetime.datetime.now() - start_time).total_seconds() + try: + return call() + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_WANT_READ: + need_select = True + continue + raise + raise_ssl_error(timeout_error) def _get_cookie(self, ssl): assert self._listening @@ -426,14 +455,12 @@ class SSLConnection(object): """ _logger.debug("Initiating handshake...") - self._check_nbio() try: - SSL_do_handshake(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_do_handshake(self._ssl.value), + ERR_HANDSHAKE_TIMEOUT) except openssl_error() as err: - if err.ssl_error == SSL_ERROR_WANT_READ and \ - self.get_socket(True).gettimeout(): - raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err) - elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: raise_ssl_error(ERR_PORT_UNREACHABLE, err) raise self._handshake_done = True @@ -450,8 +477,8 @@ class SSLConnection(object): string containing read bytes """ - self._check_nbio() - return SSL_read(self._ssl.value, len) + return self._wrap_socket_library_call( + lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT) def write(self, data): """Write data to connection @@ -465,8 +492,8 @@ class SSLConnection(object): number of bytes actually transmitted """ - self._check_nbio() - return SSL_write(self._ssl.value, data) + return self._wrap_socket_library_call( + lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) def shutdown(self): """Shut down the DTLS connection @@ -480,9 +507,9 @@ class SSLConnection(object): # Listening server-side sockets cannot be shut down return - self._check_nbio() try: - SSL_shutdown(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) except openssl_error() as err: if err.result == 0: # close-notify alert was just sent; wait for same from peer @@ -490,7 +517,8 @@ class SSLConnection(object): # with SSL_set_read_ahead here, doing so causes a shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown # initiator side. And test_starttls does pass. - SSL_shutdown(self._ssl.value) + self._wrap_socket_library_call( + lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) else: raise if hasattr(self, "_udp_demux"): diff --git a/dtls/test/echo_seq.py b/dtls/test/echo_seq.py index 7a3bf3b..1bd91f9 100644 --- a/dtls/test/echo_seq.py +++ b/dtls/test/echo_seq.py @@ -59,7 +59,7 @@ def main(): try: conn.do_handshake() except SSLError as err: - if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("504:"): continue raise print "Completed handshaking with peer" @@ -75,7 +75,7 @@ def main(): try: message = conn.read() except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("502:"): continue if err.args[0] == SSL_ERROR_ZERO_RETURN: break @@ -91,9 +91,10 @@ def main(): assert not peer_address print "Shutdown invocation: %d" % cnt try: - conn.shutdown() + s = conn.shutdown() + s.shutdown(socket.SHUT_RDWR) except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if str(err).startswith("502:"): continue raise break diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 17f6cd4..b585889 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -15,14 +15,17 @@ import traceback import weakref import platform import threading +import time import datetime import SocketServer from SimpleHTTPServer import SimpleHTTPRequestHandler +from collections import OrderedDict import ssl from dtls import do_patch HOST = "localhost" +CONNECTION_TIMEOUT = datetime.timedelta(seconds=30) class TestSupport(object): verbose = True @@ -37,7 +40,6 @@ class TestSupport(object): def __exit__(self, exc_type, exc_value, traceback): self.server.stop() - self.server.join() self.server = None def transient_internet(self): @@ -134,7 +136,6 @@ class BasicSocketTests(unittest.TestCase): s.connect(remote) finally: server.stop() - server.join() @unittest.skipIf(platform.python_implementation() != "CPython", "Reference cycle test feasible under CPython only") @@ -331,9 +332,10 @@ class ThreadedEchoServer(threading.Thread): self.server = server self.running = False self.sock = connsock - self.sock.setblocking(True) + self.sock.settimeout(CONNECTION_TIMEOUT.total_seconds()) self.sslconn = connsock threading.Thread.__init__(self) + server.register_handler(True) self.daemon = True def show_conn_details(self): @@ -389,6 +391,7 @@ class ThreadedEchoServer(threading.Thread): return self.sock.send(bytes) def close(self): + self.server.register_handler(False) if self.sslconn: self.sslconn.close() else: @@ -486,6 +489,8 @@ class ThreadedEchoServer(threading.Thread): self.starttls_server = starttls_server self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) self.flag = None + self.num_handlers = 0 + self.num_handlers_lock = threading.Lock() self.sock = ssl.wrap_socket(self.sock, server_side=True, certfile=self.certificate, cert_reqs=self.certreqs, @@ -504,6 +509,7 @@ class ThreadedEchoServer(threading.Thread): def start(self, flag=None): self.flag = flag + self.starter = threading.current_thread().ident threading.Thread.start(self) def run(self): @@ -529,8 +535,26 @@ class ThreadedEchoServer(threading.Thread): self.stop() self.sock.close() + def register_handler(self, add): + with self.num_handlers_lock: + if add: + self.num_handlers += 1 + else: + self.num_handlers -= 1 + assert self.num_handlers >= 0 + def stop(self): self.active = False + if self.starter != threading.current_thread().ident: + return + self.join() # don't allow spawning new handlers after we've checked + last_msg = datetime.datetime.now() + while self.num_handlers: + time.sleep(0.05) + now = datetime.datetime.now() + if now > last_msg + datetime.timedelta(seconds=1): + sys.stdout.write(' server: waiting for connections to close\n') + last_msg = now class AsyncoreEchoServer(threading.Thread): @@ -538,9 +562,10 @@ class AsyncoreEchoServer(threading.Thread): class ConnectionHandler(asyncore.dispatcher): - def __init__(self, conn, timeout_tracker): + def __init__(self, conn, timeout_tracker, server): asyncore.dispatcher.__init__(self, conn) self._timeout_tracker = timeout_tracker + self._server = server self._ssl_accepting = True # Complete the handshake self.handle_read_event() @@ -585,6 +610,11 @@ class AsyncoreEchoServer(threading.Thread): data = self.recv(1024) if data and data.strip() != 'over': self.send(data.lower()) + if self.connected: + self._server.reset_timeout(self) + self._server.check_timeout() + if not self.connected: # above called handle_close + return delta = self.socket.get_timeout() if delta: self._timeout_tracker[self] = \ @@ -593,6 +623,7 @@ class AsyncoreEchoServer(threading.Thread): def handle_close(self): if self._timeout_tracker.has_key(self): self._timeout_tracker.pop(self) + self._server._handlers.pop(self) self.close() if test_support.verbose: sys.stdout.write(" server: closed connection %s\n" % @@ -604,6 +635,7 @@ class AsyncoreEchoServer(threading.Thread): def __init__(self, certfile, timeout_tracker): asyncore.dispatcher.__init__(self) self._timeout_tracker = timeout_tracker + self._handlers = OrderedDict() sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) sock.setblocking(False) sock.bind((HOST, 0)) @@ -618,17 +650,43 @@ class AsyncoreEchoServer(threading.Thread): return False def handle_accept(self): + self.check_timeout() acc_ret = self.accept() if acc_ret: sock_obj, addr = acc_ret if test_support.verbose: sys.stdout.write(" server: new connection from " + "%s:%s\n" % (addr[0], str(addr[1:]))) - self.ConnectionHandler(sock_obj, self._timeout_tracker) + self._handlers[self.ConnectionHandler(sock_obj, + self._timeout_tracker, + self)] = \ + datetime.datetime.now() def handle_error(self): raise + def reset_timeout(self, handler): + if self._handlers.has_key(handler): + self._handlers.pop(handler) + self._handlers[handler] = datetime.datetime.now() + + def check_timeout(self): + now = datetime.datetime.now() + while True: + try: + handler = self._handlers.__iter__().next() # oldest handler + except StopIteration: + break # there are no more handlers + if now > self._handlers[handler] + CONNECTION_TIMEOUT: + handler.handle_close() + else: + break # the oldest handlers has not yet timed out + + def close(self): + map(lambda x: x.handle_close(), self._handlers.keys()) + assert not self._handlers + asyncore.dispatcher.close(self) + def __init__(self, certfile): self.flag = None self.active = False @@ -660,6 +718,7 @@ class AsyncoreEchoServer(threading.Thread): def stop(self): self.active = False + self.join() self.server.close() # Note that this HTTP-over-UDP server does not implement packet recovery and @@ -800,7 +859,6 @@ def bad_cert_test(certfile): raise AssertionError("Use of invalid cert should have failed!") finally: server.stop() - server.join() def server_params_test(certfile, protocol, certreqs, cacertsfile, client_certfile, client_protocol=None, @@ -854,7 +912,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile, s.close() finally: server.stop() - server.join() def try_protocol_combo(server_protocol, client_protocol, @@ -950,10 +1007,10 @@ class ThreadedTests(unittest.TestCase): self.fail( "Missing or invalid 'organizationName' field in " "certificate subject; should be 'Ray Srv Inc'.") + s.write("over\n") s.close() finally: server.stop() - server.join() def test_empty_cert(self): """Connecting with an empty cert file""" @@ -1045,7 +1102,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - server.join() def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" @@ -1095,7 +1151,6 @@ class ThreadedTests(unittest.TestCase): self.assertEqual(d1, ''.join(d2)) finally: server.stop() - server.join() def test_asyncore_server(self): """Check the example asyncore integration.""" @@ -1130,8 +1185,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - # wait for server thread to end - server.join() def test_recv_send(self): """Test recv(), send() and friends.""" @@ -1246,7 +1299,6 @@ class ThreadedTests(unittest.TestCase): s.close() finally: server.stop() - server.join() def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout