Echo server connection timeout and monitoring
Prior to this change, the unit tests' echo servers' connections would sometimes linger beyond server termination. A timeout mechanism is now implemented, which will terminate a connection and clean up its resources when the timeout is reached. For the purpose of unit testing, test echo servers now assert that all connections have been terminated when the servers are closed. For threaded echo servers, the timeout mechanism involves the use of sockets with timeouts instead of blocking sockets. This required an implementation with the proper handling of timeout sockets at the sslconnection level.incoming
parent
821952b669
commit
9bee9d049f
|
@ -26,8 +26,10 @@ ERR_BOTH_KEY_CERT_FILES = 500
|
|||
ERR_BOTH_KEY_CERT_FILES_SVR = 298
|
||||
ERR_NO_CERTS = 331
|
||||
ERR_NO_CIPHER = 501
|
||||
ERR_HANDSHAKE_TIMEOUT = 502
|
||||
ERR_PORT_UNREACHABLE = 503
|
||||
ERR_READ_TIMEOUT = 502
|
||||
ERR_WRITE_TIMEOUT = 503
|
||||
ERR_HANDSHAKE_TIMEOUT = 504
|
||||
ERR_PORT_UNREACHABLE = 505
|
||||
ERR_COOKIE_MISMATCH = 0x1408A134
|
||||
|
||||
|
||||
|
@ -88,6 +90,8 @@ _ssl_errors = {
|
|||
ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \
|
||||
"specified for server-side operation",
|
||||
ERR_NO_CIPHER: "No cipher can be selected.",
|
||||
ERR_READ_TIMEOUT: "The read operation timed out",
|
||||
ERR_WRITE_TIMEOUT: "The write operation timed out",
|
||||
ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out",
|
||||
ERR_PORT_UNREACHABLE: "The peer address is not reachable",
|
||||
}
|
||||
|
|
|
@ -26,14 +26,17 @@ library's ssl module, since its values can be passed to this module.
|
|||
import errno
|
||||
import socket
|
||||
import hmac
|
||||
import datetime
|
||||
from logging import getLogger
|
||||
from os import urandom
|
||||
from select import select
|
||||
from weakref import proxy
|
||||
from err import openssl_error, InvalidSocketError
|
||||
from err import raise_ssl_error
|
||||
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
|
||||
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS
|
||||
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
|
||||
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
|
||||
from x509 import _X509, decode_cert
|
||||
from tlock import tlock_init
|
||||
from openssl import *
|
||||
|
@ -214,9 +217,35 @@ class SSLConnection(object):
|
|||
return lambda: self.do_handshake()
|
||||
|
||||
def _check_nbio(self):
|
||||
BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None)
|
||||
timeout = self._sock.gettimeout()
|
||||
BIO_set_nbio(self._wbio.value, timeout is not None)
|
||||
if self._wbio is not self._rbio:
|
||||
BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None)
|
||||
timeout = self._rsock.gettimeout()
|
||||
BIO_set_nbio(self._rbio.value, timeout is not None)
|
||||
return timeout # read channel timeout
|
||||
|
||||
def _wrap_socket_library_call(self, call, timeout_error):
|
||||
timeout_sec_start = timeout_sec = self._check_nbio()
|
||||
# Pass the call if the socket is blocking or non-blocking
|
||||
if not timeout_sec: # None (blocking) or zero (non-blocking)
|
||||
return call()
|
||||
start_time = datetime.datetime.now()
|
||||
read_sock = self.get_socket(True)
|
||||
need_select = False
|
||||
while timeout_sec > 0:
|
||||
if need_select:
|
||||
if not select([read_sock], [], [], timeout_sec)[0]:
|
||||
break
|
||||
timeout_sec = timeout_sec_start - \
|
||||
(datetime.datetime.now() - start_time).total_seconds()
|
||||
try:
|
||||
return call()
|
||||
except openssl_error() as err:
|
||||
if err.ssl_error == SSL_ERROR_WANT_READ:
|
||||
need_select = True
|
||||
continue
|
||||
raise
|
||||
raise_ssl_error(timeout_error)
|
||||
|
||||
def _get_cookie(self, ssl):
|
||||
assert self._listening
|
||||
|
@ -426,14 +455,12 @@ class SSLConnection(object):
|
|||
"""
|
||||
|
||||
_logger.debug("Initiating handshake...")
|
||||
self._check_nbio()
|
||||
try:
|
||||
SSL_do_handshake(self._ssl.value)
|
||||
self._wrap_socket_library_call(
|
||||
lambda: SSL_do_handshake(self._ssl.value),
|
||||
ERR_HANDSHAKE_TIMEOUT)
|
||||
except openssl_error() as err:
|
||||
if err.ssl_error == SSL_ERROR_WANT_READ and \
|
||||
self.get_socket(True).gettimeout():
|
||||
raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err)
|
||||
elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
|
||||
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
|
||||
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
|
||||
raise
|
||||
self._handshake_done = True
|
||||
|
@ -450,8 +477,8 @@ class SSLConnection(object):
|
|||
string containing read bytes
|
||||
"""
|
||||
|
||||
self._check_nbio()
|
||||
return SSL_read(self._ssl.value, len)
|
||||
return self._wrap_socket_library_call(
|
||||
lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT)
|
||||
|
||||
def write(self, data):
|
||||
"""Write data to connection
|
||||
|
@ -465,8 +492,8 @@ class SSLConnection(object):
|
|||
number of bytes actually transmitted
|
||||
"""
|
||||
|
||||
self._check_nbio()
|
||||
return SSL_write(self._ssl.value, data)
|
||||
return self._wrap_socket_library_call(
|
||||
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shut down the DTLS connection
|
||||
|
@ -480,9 +507,9 @@ class SSLConnection(object):
|
|||
# Listening server-side sockets cannot be shut down
|
||||
return
|
||||
|
||||
self._check_nbio()
|
||||
try:
|
||||
SSL_shutdown(self._ssl.value)
|
||||
self._wrap_socket_library_call(
|
||||
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
|
||||
except openssl_error() as err:
|
||||
if err.result == 0:
|
||||
# close-notify alert was just sent; wait for same from peer
|
||||
|
@ -490,7 +517,8 @@ class SSLConnection(object):
|
|||
# with SSL_set_read_ahead here, doing so causes a shutdown
|
||||
# failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown
|
||||
# initiator side. And test_starttls does pass.
|
||||
SSL_shutdown(self._ssl.value)
|
||||
self._wrap_socket_library_call(
|
||||
lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT)
|
||||
else:
|
||||
raise
|
||||
if hasattr(self, "_udp_demux"):
|
||||
|
|
|
@ -59,7 +59,7 @@ def main():
|
|||
try:
|
||||
conn.do_handshake()
|
||||
except SSLError as err:
|
||||
if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ:
|
||||
if str(err).startswith("504:"):
|
||||
continue
|
||||
raise
|
||||
print "Completed handshaking with peer"
|
||||
|
@ -75,7 +75,7 @@ def main():
|
|||
try:
|
||||
message = conn.read()
|
||||
except SSLError as err:
|
||||
if err.args[0] == SSL_ERROR_WANT_READ:
|
||||
if str(err).startswith("502:"):
|
||||
continue
|
||||
if err.args[0] == SSL_ERROR_ZERO_RETURN:
|
||||
break
|
||||
|
@ -91,9 +91,10 @@ def main():
|
|||
assert not peer_address
|
||||
print "Shutdown invocation: %d" % cnt
|
||||
try:
|
||||
conn.shutdown()
|
||||
s = conn.shutdown()
|
||||
s.shutdown(socket.SHUT_RDWR)
|
||||
except SSLError as err:
|
||||
if err.args[0] == SSL_ERROR_WANT_READ:
|
||||
if str(err).startswith("502:"):
|
||||
continue
|
||||
raise
|
||||
break
|
||||
|
|
|
@ -15,14 +15,17 @@ import traceback
|
|||
import weakref
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
import datetime
|
||||
import SocketServer
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
from collections import OrderedDict
|
||||
|
||||
import ssl
|
||||
from dtls import do_patch
|
||||
|
||||
HOST = "localhost"
|
||||
CONNECTION_TIMEOUT = datetime.timedelta(seconds=30)
|
||||
|
||||
class TestSupport(object):
|
||||
verbose = True
|
||||
|
@ -37,7 +40,6 @@ class TestSupport(object):
|
|||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.server.stop()
|
||||
self.server.join()
|
||||
self.server = None
|
||||
|
||||
def transient_internet(self):
|
||||
|
@ -134,7 +136,6 @@ class BasicSocketTests(unittest.TestCase):
|
|||
s.connect(remote)
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
@unittest.skipIf(platform.python_implementation() != "CPython",
|
||||
"Reference cycle test feasible under CPython only")
|
||||
|
@ -331,9 +332,10 @@ class ThreadedEchoServer(threading.Thread):
|
|||
self.server = server
|
||||
self.running = False
|
||||
self.sock = connsock
|
||||
self.sock.setblocking(True)
|
||||
self.sock.settimeout(CONNECTION_TIMEOUT.total_seconds())
|
||||
self.sslconn = connsock
|
||||
threading.Thread.__init__(self)
|
||||
server.register_handler(True)
|
||||
self.daemon = True
|
||||
|
||||
def show_conn_details(self):
|
||||
|
@ -389,6 +391,7 @@ class ThreadedEchoServer(threading.Thread):
|
|||
return self.sock.send(bytes)
|
||||
|
||||
def close(self):
|
||||
self.server.register_handler(False)
|
||||
if self.sslconn:
|
||||
self.sslconn.close()
|
||||
else:
|
||||
|
@ -486,6 +489,8 @@ class ThreadedEchoServer(threading.Thread):
|
|||
self.starttls_server = starttls_server
|
||||
self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
|
||||
self.flag = None
|
||||
self.num_handlers = 0
|
||||
self.num_handlers_lock = threading.Lock()
|
||||
self.sock = ssl.wrap_socket(self.sock, server_side=True,
|
||||
certfile=self.certificate,
|
||||
cert_reqs=self.certreqs,
|
||||
|
@ -504,6 +509,7 @@ class ThreadedEchoServer(threading.Thread):
|
|||
|
||||
def start(self, flag=None):
|
||||
self.flag = flag
|
||||
self.starter = threading.current_thread().ident
|
||||
threading.Thread.start(self)
|
||||
|
||||
def run(self):
|
||||
|
@ -529,8 +535,26 @@ class ThreadedEchoServer(threading.Thread):
|
|||
self.stop()
|
||||
self.sock.close()
|
||||
|
||||
def register_handler(self, add):
|
||||
with self.num_handlers_lock:
|
||||
if add:
|
||||
self.num_handlers += 1
|
||||
else:
|
||||
self.num_handlers -= 1
|
||||
assert self.num_handlers >= 0
|
||||
|
||||
def stop(self):
|
||||
self.active = False
|
||||
if self.starter != threading.current_thread().ident:
|
||||
return
|
||||
self.join() # don't allow spawning new handlers after we've checked
|
||||
last_msg = datetime.datetime.now()
|
||||
while self.num_handlers:
|
||||
time.sleep(0.05)
|
||||
now = datetime.datetime.now()
|
||||
if now > last_msg + datetime.timedelta(seconds=1):
|
||||
sys.stdout.write(' server: waiting for connections to close\n')
|
||||
last_msg = now
|
||||
|
||||
class AsyncoreEchoServer(threading.Thread):
|
||||
|
||||
|
@ -538,9 +562,10 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
|
||||
class ConnectionHandler(asyncore.dispatcher):
|
||||
|
||||
def __init__(self, conn, timeout_tracker):
|
||||
def __init__(self, conn, timeout_tracker, server):
|
||||
asyncore.dispatcher.__init__(self, conn)
|
||||
self._timeout_tracker = timeout_tracker
|
||||
self._server = server
|
||||
self._ssl_accepting = True
|
||||
# Complete the handshake
|
||||
self.handle_read_event()
|
||||
|
@ -585,6 +610,11 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
data = self.recv(1024)
|
||||
if data and data.strip() != 'over':
|
||||
self.send(data.lower())
|
||||
if self.connected:
|
||||
self._server.reset_timeout(self)
|
||||
self._server.check_timeout()
|
||||
if not self.connected: # above called handle_close
|
||||
return
|
||||
delta = self.socket.get_timeout()
|
||||
if delta:
|
||||
self._timeout_tracker[self] = \
|
||||
|
@ -593,6 +623,7 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
def handle_close(self):
|
||||
if self._timeout_tracker.has_key(self):
|
||||
self._timeout_tracker.pop(self)
|
||||
self._server._handlers.pop(self)
|
||||
self.close()
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" server: closed connection %s\n" %
|
||||
|
@ -604,6 +635,7 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
def __init__(self, certfile, timeout_tracker):
|
||||
asyncore.dispatcher.__init__(self)
|
||||
self._timeout_tracker = timeout_tracker
|
||||
self._handlers = OrderedDict()
|
||||
sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind((HOST, 0))
|
||||
|
@ -618,17 +650,43 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
return False
|
||||
|
||||
def handle_accept(self):
|
||||
self.check_timeout()
|
||||
acc_ret = self.accept()
|
||||
if acc_ret:
|
||||
sock_obj, addr = acc_ret
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" server: new connection from " +
|
||||
"%s:%s\n" % (addr[0], str(addr[1:])))
|
||||
self.ConnectionHandler(sock_obj, self._timeout_tracker)
|
||||
self._handlers[self.ConnectionHandler(sock_obj,
|
||||
self._timeout_tracker,
|
||||
self)] = \
|
||||
datetime.datetime.now()
|
||||
|
||||
def handle_error(self):
|
||||
raise
|
||||
|
||||
def reset_timeout(self, handler):
|
||||
if self._handlers.has_key(handler):
|
||||
self._handlers.pop(handler)
|
||||
self._handlers[handler] = datetime.datetime.now()
|
||||
|
||||
def check_timeout(self):
|
||||
now = datetime.datetime.now()
|
||||
while True:
|
||||
try:
|
||||
handler = self._handlers.__iter__().next() # oldest handler
|
||||
except StopIteration:
|
||||
break # there are no more handlers
|
||||
if now > self._handlers[handler] + CONNECTION_TIMEOUT:
|
||||
handler.handle_close()
|
||||
else:
|
||||
break # the oldest handlers has not yet timed out
|
||||
|
||||
def close(self):
|
||||
map(lambda x: x.handle_close(), self._handlers.keys())
|
||||
assert not self._handlers
|
||||
asyncore.dispatcher.close(self)
|
||||
|
||||
def __init__(self, certfile):
|
||||
self.flag = None
|
||||
self.active = False
|
||||
|
@ -660,6 +718,7 @@ class AsyncoreEchoServer(threading.Thread):
|
|||
|
||||
def stop(self):
|
||||
self.active = False
|
||||
self.join()
|
||||
self.server.close()
|
||||
|
||||
# Note that this HTTP-over-UDP server does not implement packet recovery and
|
||||
|
@ -800,7 +859,6 @@ def bad_cert_test(certfile):
|
|||
raise AssertionError("Use of invalid cert should have failed!")
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def server_params_test(certfile, protocol, certreqs, cacertsfile,
|
||||
client_certfile, client_protocol=None,
|
||||
|
@ -854,7 +912,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
|
|||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def try_protocol_combo(server_protocol,
|
||||
client_protocol,
|
||||
|
@ -950,10 +1007,10 @@ class ThreadedTests(unittest.TestCase):
|
|||
self.fail(
|
||||
"Missing or invalid 'organizationName' field in "
|
||||
"certificate subject; should be 'Ray Srv Inc'.")
|
||||
s.write("over\n")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_empty_cert(self):
|
||||
"""Connecting with an empty cert file"""
|
||||
|
@ -1045,7 +1102,6 @@ class ThreadedTests(unittest.TestCase):
|
|||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_socketserver(self):
|
||||
"""Using a SocketServer to create and manage SSL connections."""
|
||||
|
@ -1095,7 +1151,6 @@ class ThreadedTests(unittest.TestCase):
|
|||
self.assertEqual(d1, ''.join(d2))
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_asyncore_server(self):
|
||||
"""Check the example asyncore integration."""
|
||||
|
@ -1130,8 +1185,6 @@ class ThreadedTests(unittest.TestCase):
|
|||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
# wait for server thread to end
|
||||
server.join()
|
||||
|
||||
def test_recv_send(self):
|
||||
"""Test recv(), send() and friends."""
|
||||
|
@ -1246,7 +1299,6 @@ class ThreadedTests(unittest.TestCase):
|
|||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# Issue #5103: SSL handshake must respect the socket timeout
|
||||
|
|
Loading…
Reference in New Issue