Echo server connection timeout and monitoring

Prior to this change, the unit tests' echo servers' connections would sometimes
linger beyond server termination. A timeout mechanism is now implemented, which
will terminate a connection and clean up its resources when the timeout is
reached. For the purpose of unit testing, test echo servers now assert that
all connections have been terminated when the servers are closed.

For threaded echo servers, the timeout mechanism involves the use of sockets
with timeouts instead of blocking sockets. This required an implementation with
the proper handling of timeout sockets at the sslconnection level.
incoming
Ray Brown 2012-11-25 13:44:55 -08:00
parent 821952b669
commit 9bee9d049f
4 changed files with 119 additions and 34 deletions

View File

@ -26,8 +26,10 @@ ERR_BOTH_KEY_CERT_FILES = 500
ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_BOTH_KEY_CERT_FILES_SVR = 298
ERR_NO_CERTS = 331 ERR_NO_CERTS = 331
ERR_NO_CIPHER = 501 ERR_NO_CIPHER = 501
ERR_HANDSHAKE_TIMEOUT = 502 ERR_READ_TIMEOUT = 502
ERR_PORT_UNREACHABLE = 503 ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505
ERR_COOKIE_MISMATCH = 0x1408A134 ERR_COOKIE_MISMATCH = 0x1408A134
@ -88,6 +90,8 @@ _ssl_errors = {
ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \ ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \
"specified for server-side operation", "specified for server-side operation",
ERR_NO_CIPHER: "No cipher can be selected.", ERR_NO_CIPHER: "No cipher can be selected.",
ERR_READ_TIMEOUT: "The read operation timed out",
ERR_WRITE_TIMEOUT: "The write operation timed out",
ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out", ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out",
ERR_PORT_UNREACHABLE: "The peer address is not reachable", ERR_PORT_UNREACHABLE: "The peer address is not reachable",
} }

View File

@ -26,14 +26,17 @@ library's ssl module, since its values can be passed to this module.
import errno import errno
import socket import socket
import hmac import hmac
import datetime
from logging import getLogger from logging import getLogger
from os import urandom from os import urandom
from select import select
from weakref import proxy from weakref import proxy
from err import openssl_error, InvalidSocketError from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from x509 import _X509, decode_cert from x509 import _X509, decode_cert
from tlock import tlock_init from tlock import tlock_init
from openssl import * from openssl import *
@ -214,9 +217,35 @@ class SSLConnection(object):
return lambda: self.do_handshake() return lambda: self.do_handshake()
def _check_nbio(self): def _check_nbio(self):
BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) timeout = self._sock.gettimeout()
BIO_set_nbio(self._wbio.value, timeout is not None)
if self._wbio is not self._rbio: if self._wbio is not self._rbio:
BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None) timeout = self._rsock.gettimeout()
BIO_set_nbio(self._rbio.value, timeout is not None)
return timeout # read channel timeout
def _wrap_socket_library_call(self, call, timeout_error):
timeout_sec_start = timeout_sec = self._check_nbio()
# Pass the call if the socket is blocking or non-blocking
if not timeout_sec: # None (blocking) or zero (non-blocking)
return call()
start_time = datetime.datetime.now()
read_sock = self.get_socket(True)
need_select = False
while timeout_sec > 0:
if need_select:
if not select([read_sock], [], [], timeout_sec)[0]:
break
timeout_sec = timeout_sec_start - \
(datetime.datetime.now() - start_time).total_seconds()
try:
return call()
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ:
need_select = True
continue
raise
raise_ssl_error(timeout_error)
def _get_cookie(self, ssl): def _get_cookie(self, ssl):
assert self._listening assert self._listening
@ -426,14 +455,12 @@ class SSLConnection(object):
""" """
_logger.debug("Initiating handshake...") _logger.debug("Initiating handshake...")
self._check_nbio()
try: try:
SSL_do_handshake(self._ssl.value) self._wrap_socket_library_call(
lambda: SSL_do_handshake(self._ssl.value),
ERR_HANDSHAKE_TIMEOUT)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ and \ if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
self.get_socket(True).gettimeout():
raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err)
elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err) raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise raise
self._handshake_done = True self._handshake_done = True
@ -450,8 +477,8 @@ class SSLConnection(object):
string containing read bytes string containing read bytes
""" """
self._check_nbio() return self._wrap_socket_library_call(
return SSL_read(self._ssl.value, len) lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT)
def write(self, data): def write(self, data):
"""Write data to connection """Write data to connection
@ -465,8 +492,8 @@ class SSLConnection(object):
number of bytes actually transmitted number of bytes actually transmitted
""" """
self._check_nbio() return self._wrap_socket_library_call(
return SSL_write(self._ssl.value, data) lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
def shutdown(self): def shutdown(self):
"""Shut down the DTLS connection """Shut down the DTLS connection
@ -480,9 +507,9 @@ class SSLConnection(object):
# Listening server-side sockets cannot be shut down # Listening server-side sockets cannot be shut down
return return
self._check_nbio()
try: try:
SSL_shutdown(self._ssl.value) self._wrap_socket_library_call(
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
except openssl_error() as err: except openssl_error() as err:
if err.result == 0: if err.result == 0:
# close-notify alert was just sent; wait for same from peer # close-notify alert was just sent; wait for same from peer
@ -490,7 +517,8 @@ class SSLConnection(object):
# with SSL_set_read_ahead here, doing so causes a shutdown # with SSL_set_read_ahead here, doing so causes a shutdown
# failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown
# initiator side. And test_starttls does pass. # initiator side. And test_starttls does pass.
SSL_shutdown(self._ssl.value) self._wrap_socket_library_call(
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
else: else:
raise raise
if hasattr(self, "_udp_demux"): if hasattr(self, "_udp_demux"):

View File

@ -59,7 +59,7 @@ def main():
try: try:
conn.do_handshake() conn.do_handshake()
except SSLError as err: except SSLError as err:
if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ: if str(err).startswith("504:"):
continue continue
raise raise
print "Completed handshaking with peer" print "Completed handshaking with peer"
@ -75,7 +75,7 @@ def main():
try: try:
message = conn.read() message = conn.read()
except SSLError as err: except SSLError as err:
if err.args[0] == SSL_ERROR_WANT_READ: if str(err).startswith("502:"):
continue continue
if err.args[0] == SSL_ERROR_ZERO_RETURN: if err.args[0] == SSL_ERROR_ZERO_RETURN:
break break
@ -91,9 +91,10 @@ def main():
assert not peer_address assert not peer_address
print "Shutdown invocation: %d" % cnt print "Shutdown invocation: %d" % cnt
try: try:
conn.shutdown() s = conn.shutdown()
s.shutdown(socket.SHUT_RDWR)
except SSLError as err: except SSLError as err:
if err.args[0] == SSL_ERROR_WANT_READ: if str(err).startswith("502:"):
continue continue
raise raise
break break

View File

@ -15,14 +15,17 @@ import traceback
import weakref import weakref
import platform import platform
import threading import threading
import time
import datetime import datetime
import SocketServer import SocketServer
from SimpleHTTPServer import SimpleHTTPRequestHandler from SimpleHTTPServer import SimpleHTTPRequestHandler
from collections import OrderedDict
import ssl import ssl
from dtls import do_patch from dtls import do_patch
HOST = "localhost" HOST = "localhost"
CONNECTION_TIMEOUT = datetime.timedelta(seconds=30)
class TestSupport(object): class TestSupport(object):
verbose = True verbose = True
@ -37,7 +40,6 @@ class TestSupport(object):
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.server.stop() self.server.stop()
self.server.join()
self.server = None self.server = None
def transient_internet(self): def transient_internet(self):
@ -134,7 +136,6 @@ class BasicSocketTests(unittest.TestCase):
s.connect(remote) s.connect(remote)
finally: finally:
server.stop() server.stop()
server.join()
@unittest.skipIf(platform.python_implementation() != "CPython", @unittest.skipIf(platform.python_implementation() != "CPython",
"Reference cycle test feasible under CPython only") "Reference cycle test feasible under CPython only")
@ -331,9 +332,10 @@ class ThreadedEchoServer(threading.Thread):
self.server = server self.server = server
self.running = False self.running = False
self.sock = connsock self.sock = connsock
self.sock.setblocking(True) self.sock.settimeout(CONNECTION_TIMEOUT.total_seconds())
self.sslconn = connsock self.sslconn = connsock
threading.Thread.__init__(self) threading.Thread.__init__(self)
server.register_handler(True)
self.daemon = True self.daemon = True
def show_conn_details(self): def show_conn_details(self):
@ -389,6 +391,7 @@ class ThreadedEchoServer(threading.Thread):
return self.sock.send(bytes) return self.sock.send(bytes)
def close(self): def close(self):
self.server.register_handler(False)
if self.sslconn: if self.sslconn:
self.sslconn.close() self.sslconn.close()
else: else:
@ -486,6 +489,8 @@ class ThreadedEchoServer(threading.Thread):
self.starttls_server = starttls_server self.starttls_server = starttls_server
self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
self.flag = None self.flag = None
self.num_handlers = 0
self.num_handlers_lock = threading.Lock()
self.sock = ssl.wrap_socket(self.sock, server_side=True, self.sock = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.certificate, certfile=self.certificate,
cert_reqs=self.certreqs, cert_reqs=self.certreqs,
@ -504,6 +509,7 @@ class ThreadedEchoServer(threading.Thread):
def start(self, flag=None): def start(self, flag=None):
self.flag = flag self.flag = flag
self.starter = threading.current_thread().ident
threading.Thread.start(self) threading.Thread.start(self)
def run(self): def run(self):
@ -529,8 +535,26 @@ class ThreadedEchoServer(threading.Thread):
self.stop() self.stop()
self.sock.close() self.sock.close()
def register_handler(self, add):
with self.num_handlers_lock:
if add:
self.num_handlers += 1
else:
self.num_handlers -= 1
assert self.num_handlers >= 0
def stop(self): def stop(self):
self.active = False self.active = False
if self.starter != threading.current_thread().ident:
return
self.join() # don't allow spawning new handlers after we've checked
last_msg = datetime.datetime.now()
while self.num_handlers:
time.sleep(0.05)
now = datetime.datetime.now()
if now > last_msg + datetime.timedelta(seconds=1):
sys.stdout.write(' server: waiting for connections to close\n')
last_msg = now
class AsyncoreEchoServer(threading.Thread): class AsyncoreEchoServer(threading.Thread):
@ -538,9 +562,10 @@ class AsyncoreEchoServer(threading.Thread):
class ConnectionHandler(asyncore.dispatcher): class ConnectionHandler(asyncore.dispatcher):
def __init__(self, conn, timeout_tracker): def __init__(self, conn, timeout_tracker, server):
asyncore.dispatcher.__init__(self, conn) asyncore.dispatcher.__init__(self, conn)
self._timeout_tracker = timeout_tracker self._timeout_tracker = timeout_tracker
self._server = server
self._ssl_accepting = True self._ssl_accepting = True
# Complete the handshake # Complete the handshake
self.handle_read_event() self.handle_read_event()
@ -585,6 +610,11 @@ class AsyncoreEchoServer(threading.Thread):
data = self.recv(1024) data = self.recv(1024)
if data and data.strip() != 'over': if data and data.strip() != 'over':
self.send(data.lower()) self.send(data.lower())
if self.connected:
self._server.reset_timeout(self)
self._server.check_timeout()
if not self.connected: # above called handle_close
return
delta = self.socket.get_timeout() delta = self.socket.get_timeout()
if delta: if delta:
self._timeout_tracker[self] = \ self._timeout_tracker[self] = \
@ -593,6 +623,7 @@ class AsyncoreEchoServer(threading.Thread):
def handle_close(self): def handle_close(self):
if self._timeout_tracker.has_key(self): if self._timeout_tracker.has_key(self):
self._timeout_tracker.pop(self) self._timeout_tracker.pop(self)
self._server._handlers.pop(self)
self.close() self.close()
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" server: closed connection %s\n" % sys.stdout.write(" server: closed connection %s\n" %
@ -604,6 +635,7 @@ class AsyncoreEchoServer(threading.Thread):
def __init__(self, certfile, timeout_tracker): def __init__(self, certfile, timeout_tracker):
asyncore.dispatcher.__init__(self) asyncore.dispatcher.__init__(self)
self._timeout_tracker = timeout_tracker self._timeout_tracker = timeout_tracker
self._handlers = OrderedDict()
sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM) sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
sock.setblocking(False) sock.setblocking(False)
sock.bind((HOST, 0)) sock.bind((HOST, 0))
@ -618,17 +650,43 @@ class AsyncoreEchoServer(threading.Thread):
return False return False
def handle_accept(self): def handle_accept(self):
self.check_timeout()
acc_ret = self.accept() acc_ret = self.accept()
if acc_ret: if acc_ret:
sock_obj, addr = acc_ret sock_obj, addr = acc_ret
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" server: new connection from " + sys.stdout.write(" server: new connection from " +
"%s:%s\n" % (addr[0], str(addr[1:]))) "%s:%s\n" % (addr[0], str(addr[1:])))
self.ConnectionHandler(sock_obj, self._timeout_tracker) self._handlers[self.ConnectionHandler(sock_obj,
self._timeout_tracker,
self)] = \
datetime.datetime.now()
def handle_error(self): def handle_error(self):
raise raise
def reset_timeout(self, handler):
if self._handlers.has_key(handler):
self._handlers.pop(handler)
self._handlers[handler] = datetime.datetime.now()
def check_timeout(self):
now = datetime.datetime.now()
while True:
try:
handler = self._handlers.__iter__().next() # oldest handler
except StopIteration:
break # there are no more handlers
if now > self._handlers[handler] + CONNECTION_TIMEOUT:
handler.handle_close()
else:
break # the oldest handlers has not yet timed out
def close(self):
map(lambda x: x.handle_close(), self._handlers.keys())
assert not self._handlers
asyncore.dispatcher.close(self)
def __init__(self, certfile): def __init__(self, certfile):
self.flag = None self.flag = None
self.active = False self.active = False
@ -660,6 +718,7 @@ class AsyncoreEchoServer(threading.Thread):
def stop(self): def stop(self):
self.active = False self.active = False
self.join()
self.server.close() self.server.close()
# Note that this HTTP-over-UDP server does not implement packet recovery and # Note that this HTTP-over-UDP server does not implement packet recovery and
@ -800,7 +859,6 @@ def bad_cert_test(certfile):
raise AssertionError("Use of invalid cert should have failed!") raise AssertionError("Use of invalid cert should have failed!")
finally: finally:
server.stop() server.stop()
server.join()
def server_params_test(certfile, protocol, certreqs, cacertsfile, def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, client_certfile, client_protocol=None,
@ -854,7 +912,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join()
def try_protocol_combo(server_protocol, def try_protocol_combo(server_protocol,
client_protocol, client_protocol,
@ -950,10 +1007,10 @@ class ThreadedTests(unittest.TestCase):
self.fail( self.fail(
"Missing or invalid 'organizationName' field in " "Missing or invalid 'organizationName' field in "
"certificate subject; should be 'Ray Srv Inc'.") "certificate subject; should be 'Ray Srv Inc'.")
s.write("over\n")
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join()
def test_empty_cert(self): def test_empty_cert(self):
"""Connecting with an empty cert file""" """Connecting with an empty cert file"""
@ -1045,7 +1102,6 @@ class ThreadedTests(unittest.TestCase):
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join()
def test_socketserver(self): def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections.""" """Using a SocketServer to create and manage SSL connections."""
@ -1095,7 +1151,6 @@ class ThreadedTests(unittest.TestCase):
self.assertEqual(d1, ''.join(d2)) self.assertEqual(d1, ''.join(d2))
finally: finally:
server.stop() server.stop()
server.join()
def test_asyncore_server(self): def test_asyncore_server(self):
"""Check the example asyncore integration.""" """Check the example asyncore integration."""
@ -1130,8 +1185,6 @@ class ThreadedTests(unittest.TestCase):
s.close() s.close()
finally: finally:
server.stop() server.stop()
# wait for server thread to end
server.join()
def test_recv_send(self): def test_recv_send(self):
"""Test recv(), send() and friends.""" """Test recv(), send() and friends."""
@ -1246,7 +1299,6 @@ class ThreadedTests(unittest.TestCase):
s.close() s.close()
finally: finally:
server.stop() server.stop()
server.join()
def test_handshake_timeout(self): def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout # Issue #5103: SSL handshake must respect the socket timeout