Echo server connection timeout and monitoring

Prior to this change, the unit tests' echo servers' connections would sometimes
linger beyond server termination. A timeout mechanism is now implemented, which
will terminate a connection and clean up its resources when the timeout is
reached. For the purpose of unit testing, test echo servers now assert that
all connections have been terminated when the servers are closed.

For threaded echo servers, the timeout mechanism involves the use of sockets
with timeouts instead of blocking sockets. This required an implementation with
the proper handling of timeout sockets at the sslconnection level.
incoming
Ray Brown 2012-11-25 13:44:55 -08:00
parent 821952b669
commit 9bee9d049f
4 changed files with 119 additions and 34 deletions

View File

@ -26,8 +26,10 @@ ERR_BOTH_KEY_CERT_FILES = 500
ERR_BOTH_KEY_CERT_FILES_SVR = 298
ERR_NO_CERTS = 331
ERR_NO_CIPHER = 501
ERR_HANDSHAKE_TIMEOUT = 502
ERR_PORT_UNREACHABLE = 503
ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505
ERR_COOKIE_MISMATCH = 0x1408A134
@ -88,6 +90,8 @@ _ssl_errors = {
ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \
"specified for server-side operation",
ERR_NO_CIPHER: "No cipher can be selected.",
ERR_READ_TIMEOUT: "The read operation timed out",
ERR_WRITE_TIMEOUT: "The write operation timed out",
ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out",
ERR_PORT_UNREACHABLE: "The peer address is not reachable",
}

View File

@ -26,14 +26,17 @@ library's ssl module, since its values can be passed to this module.
import errno
import socket
import hmac
import datetime
from logging import getLogger
from os import urandom
from select import select
from weakref import proxy
from err import openssl_error, InvalidSocketError
from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from x509 import _X509, decode_cert
from tlock import tlock_init
from openssl import *
@ -214,9 +217,35 @@ class SSLConnection(object):
return lambda: self.do_handshake()
def _check_nbio(self):
BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None)
timeout = self._sock.gettimeout()
BIO_set_nbio(self._wbio.value, timeout is not None)
if self._wbio is not self._rbio:
BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None)
timeout = self._rsock.gettimeout()
BIO_set_nbio(self._rbio.value, timeout is not None)
return timeout # read channel timeout
def _wrap_socket_library_call(self, call, timeout_error):
timeout_sec_start = timeout_sec = self._check_nbio()
# Pass the call if the socket is blocking or non-blocking
if not timeout_sec: # None (blocking) or zero (non-blocking)
return call()
start_time = datetime.datetime.now()
read_sock = self.get_socket(True)
need_select = False
while timeout_sec > 0:
if need_select:
if not select([read_sock], [], [], timeout_sec)[0]:
break
timeout_sec = timeout_sec_start - \
(datetime.datetime.now() - start_time).total_seconds()
try:
return call()
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ:
need_select = True
continue
raise
raise_ssl_error(timeout_error)
def _get_cookie(self, ssl):
assert self._listening
@ -426,14 +455,12 @@ class SSLConnection(object):
"""
_logger.debug("Initiating handshake...")
self._check_nbio()
try:
SSL_do_handshake(self._ssl.value)
self._wrap_socket_library_call(
lambda: SSL_do_handshake(self._ssl.value),
ERR_HANDSHAKE_TIMEOUT)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ and \
self.get_socket(True).gettimeout():
raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err)
elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise
self._handshake_done = True
@ -450,8 +477,8 @@ class SSLConnection(object):
string containing read bytes
"""
self._check_nbio()
return SSL_read(self._ssl.value, len)
return self._wrap_socket_library_call(
lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT)
def write(self, data):
"""Write data to connection
@ -465,8 +492,8 @@ class SSLConnection(object):
number of bytes actually transmitted
"""
self._check_nbio()
return SSL_write(self._ssl.value, data)
return self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
def shutdown(self):
"""Shut down the DTLS connection
@ -480,9 +507,9 @@ class SSLConnection(object):
# Listening server-side sockets cannot be shut down
return
self._check_nbio()
try:
SSL_shutdown(self._ssl.value)
self._wrap_socket_library_call(
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
except openssl_error() as err:
if err.result == 0:
# close-notify alert was just sent; wait for same from peer
@ -490,7 +517,8 @@ class SSLConnection(object):
# with SSL_set_read_ahead here, doing so causes a shutdown
# failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown
# initiator side. And test_starttls does pass.
SSL_shutdown(self._ssl.value)
self._wrap_socket_library_call(
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
else:
raise
if hasattr(self, "_udp_demux"):

View File

@ -59,7 +59,7 @@ def main():
try:
conn.do_handshake()
except SSLError as err:
if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ:
if str(err).startswith("504:"):
continue
raise
print "Completed handshaking with peer"
@ -75,7 +75,7 @@ def main():
try:
message = conn.read()
except SSLError as err:
if err.args[0] == SSL_ERROR_WANT_READ:
if str(err).startswith("502:"):
continue
if err.args[0] == SSL_ERROR_ZERO_RETURN:
break
@ -91,9 +91,10 @@ def main():
assert not peer_address
print "Shutdown invocation: %d" % cnt
try:
conn.shutdown()
s = conn.shutdown()
s.shutdown(socket.SHUT_RDWR)
except SSLError as err:
if err.args[0] == SSL_ERROR_WANT_READ:
if str(err).startswith("502:"):
continue
raise
break

View File

@ -15,14 +15,17 @@ import traceback
import weakref
import platform
import threading
import time
import datetime
import SocketServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
from collections import OrderedDict
import ssl
from dtls import do_patch
HOST = "localhost"
CONNECTION_TIMEOUT = datetime.timedelta(seconds=30)
class TestSupport(object):
verbose = True
@ -37,7 +40,6 @@ class TestSupport(object):
def __exit__(self, exc_type, exc_value, traceback):
self.server.stop()
self.server.join()
self.server = None
def transient_internet(self):
@ -134,7 +136,6 @@ class BasicSocketTests(unittest.TestCase):
s.connect(remote)
finally:
server.stop()
server.join()
@unittest.skipIf(platform.python_implementation() != "CPython",
"Reference cycle test feasible under CPython only")
@ -331,9 +332,10 @@ class ThreadedEchoServer(threading.Thread):
self.server = server
self.running = False
self.sock = connsock
self.sock.setblocking(True)
self.sock.settimeout(CONNECTION_TIMEOUT.total_seconds())
self.sslconn = connsock
threading.Thread.__init__(self)
server.register_handler(True)
self.daemon = True
def show_conn_details(self):
@ -389,6 +391,7 @@ class ThreadedEchoServer(threading.Thread):
return self.sock.send(bytes)
def close(self):
self.server.register_handler(False)
if self.sslconn:
self.sslconn.close()
else:
@ -486,6 +489,8 @@ class ThreadedEchoServer(threading.Thread):
self.starttls_server = starttls_server
self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
self.flag = None
self.num_handlers = 0
self.num_handlers_lock = threading.Lock()
self.sock = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.certificate,
cert_reqs=self.certreqs,
@ -504,6 +509,7 @@ class ThreadedEchoServer(threading.Thread):
def start(self, flag=None):
self.flag = flag
self.starter = threading.current_thread().ident
threading.Thread.start(self)
def run(self):
@ -529,8 +535,26 @@ class ThreadedEchoServer(threading.Thread):
self.stop()
self.sock.close()
def register_handler(self, add):
with self.num_handlers_lock:
if add:
self.num_handlers += 1
else:
self.num_handlers -= 1
assert self.num_handlers >= 0
def stop(self):
self.active = False
if self.starter != threading.current_thread().ident:
return
self.join() # don't allow spawning new handlers after we've checked
last_msg = datetime.datetime.now()
while self.num_handlers:
time.sleep(0.05)
now = datetime.datetime.now()
if now > last_msg + datetime.timedelta(seconds=1):
sys.stdout.write(' server: waiting for connections to close\n')
last_msg = now
class AsyncoreEchoServer(threading.Thread):
@ -538,9 +562,10 @@ class AsyncoreEchoServer(threading.Thread):
class ConnectionHandler(asyncore.dispatcher):
def __init__(self, conn, timeout_tracker):
def __init__(self, conn, timeout_tracker, server):
asyncore.dispatcher.__init__(self, conn)
self._timeout_tracker = timeout_tracker
self._server = server
self._ssl_accepting = True
# Complete the handshake
self.handle_read_event()
@ -585,6 +610,11 @@ class AsyncoreEchoServer(threading.Thread):
data = self.recv(1024)
if data and data.strip() != 'over':
self.send(data.lower())
if self.connected:
self._server.reset_timeout(self)
self._server.check_timeout()
if not self.connected: # above called handle_close
return
delta = self.socket.get_timeout()
if delta:
self._timeout_tracker[self] = \
@ -593,6 +623,7 @@ class AsyncoreEchoServer(threading.Thread):
def handle_close(self):
if self._timeout_tracker.has_key(self):
self._timeout_tracker.pop(self)
self._server._handlers.pop(self)
self.close()
if test_support.verbose:
sys.stdout.write(" server: closed connection %s\n" %
@ -604,6 +635,7 @@ class AsyncoreEchoServer(threading.Thread):
def __init__(self, certfile, timeout_tracker):
asyncore.dispatcher.__init__(self)
self._timeout_tracker = timeout_tracker
self._handlers = OrderedDict()
sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind((HOST, 0))
@ -618,17 +650,43 @@ class AsyncoreEchoServer(threading.Thread):
return False
def handle_accept(self):
self.check_timeout()
acc_ret = self.accept()
if acc_ret:
sock_obj, addr = acc_ret
if test_support.verbose:
sys.stdout.write(" server: new connection from " +
"%s:%s\n" % (addr[0], str(addr[1:])))
self.ConnectionHandler(sock_obj, self._timeout_tracker)
self._handlers[self.ConnectionHandler(sock_obj,
self._timeout_tracker,
self)] = \
datetime.datetime.now()
def handle_error(self):
raise
def reset_timeout(self, handler):
if self._handlers.has_key(handler):
self._handlers.pop(handler)
self._handlers[handler] = datetime.datetime.now()
def check_timeout(self):
now = datetime.datetime.now()
while True:
try:
handler = self._handlers.__iter__().next() # oldest handler
except StopIteration:
break # there are no more handlers
if now > self._handlers[handler] + CONNECTION_TIMEOUT:
handler.handle_close()
else:
break # the oldest handlers has not yet timed out
def close(self):
map(lambda x: x.handle_close(), self._handlers.keys())
assert not self._handlers
asyncore.dispatcher.close(self)
def __init__(self, certfile):
self.flag = None
self.active = False
@ -660,6 +718,7 @@ class AsyncoreEchoServer(threading.Thread):
def stop(self):
self.active = False
self.join()
self.server.close()
# Note that this HTTP-over-UDP server does not implement packet recovery and
@ -800,7 +859,6 @@ def bad_cert_test(certfile):
raise AssertionError("Use of invalid cert should have failed!")
finally:
server.stop()
server.join()
def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None,
@ -854,7 +912,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
s.close()
finally:
server.stop()
server.join()
def try_protocol_combo(server_protocol,
client_protocol,
@ -950,10 +1007,10 @@ class ThreadedTests(unittest.TestCase):
self.fail(
"Missing or invalid 'organizationName' field in "
"certificate subject; should be 'Ray Srv Inc'.")
s.write("over\n")
s.close()
finally:
server.stop()
server.join()
def test_empty_cert(self):
"""Connecting with an empty cert file"""
@ -1045,7 +1102,6 @@ class ThreadedTests(unittest.TestCase):
s.close()
finally:
server.stop()
server.join()
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
@ -1095,7 +1151,6 @@ class ThreadedTests(unittest.TestCase):
self.assertEqual(d1, ''.join(d2))
finally:
server.stop()
server.join()
def test_asyncore_server(self):
"""Check the example asyncore integration."""
@ -1130,8 +1185,6 @@ class ThreadedTests(unittest.TestCase):
s.close()
finally:
server.stop()
# wait for server thread to end
server.join()
def test_recv_send(self):
"""Test recv(), send() and friends."""
@ -1246,7 +1299,6 @@ class ThreadedTests(unittest.TestCase):
s.close()
finally:
server.stop()
server.join()
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout