Add IPv6 unit tests

The unit test suite was previously being run with IPv4 protocol addresses only.
With this change, we run the entire test suite twice: first with v4 addresses,
and then with v6 addresses, for all client and server-side sockets.
incoming
Ray Brown 2012-11-21 19:20:16 -08:00
parent 22083e8221
commit 1ce7243af5
2 changed files with 43 additions and 51 deletions

View File

@ -18,7 +18,7 @@ has the following effects:
""" """
from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_DGRAM from socket import AF_INET, SOCK_DGRAM, getaddrinfo
from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION
from sslconnection import DTLS_OPENSSL_VERSION_INFO from sslconnection import DTLS_OPENSSL_VERSION_INFO
@ -58,12 +58,12 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
if ssl_version != PROTOCOL_DTLSv1: if ssl_version != PROTOCOL_DTLSv1:
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
host, port = addr
if (ca_certs is not None): if (ca_certs is not None):
cert_reqs = ssl.CERT_REQUIRED cert_reqs = ssl.CERT_REQUIRED
else: else:
cert_reqs = ssl.CERT_NONE cert_reqs = ssl.CERT_NONE
s = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), af = getaddrinfo(addr[0], addr[1])[0][0]
s = ssl.wrap_socket(socket(af, SOCK_DGRAM),
ssl_version=ssl_version, ssl_version=ssl_version,
cert_reqs=cert_reqs, ca_certs=ca_certs) cert_reqs=cert_reqs, ca_certs=ca_certs)
s.connect(addr) s.connect(addr)

View File

@ -56,14 +56,14 @@ class BasicTests(unittest.TestCase):
def test_sslwrap_simple(self): def test_sslwrap_simple(self):
# A crude test for the legacy API # A crude test for the legacy API
try: try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) ssl.sslwrap_simple(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
except IOError, e: except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
pass pass
else: else:
raise raise
try: try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET, ssl.sslwrap_simple(socket.socket(AF_INET4_6,
socket.SOCK_DGRAM)._sock) socket.SOCK_DGRAM)._sock)
except IOError, e: except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
@ -118,18 +118,15 @@ class BasicSocketTests(unittest.TestCase):
flag.wait() flag.wait()
remote = (HOST, server.port) remote = (HOST, server.port)
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_NONE, ciphers="ALL") cert_reqs=ssl.CERT_NONE, ciphers="ALL")
s.connect(remote) s.connect(remote)
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
s.connect(remote) s.connect(remote)
# Error checking occurs when connecting, because the SSL context # Error checking occurs when connecting, because the SSL context
# isn't created before. # isn't created before.
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_NONE, cert_reqs=ssl.CERT_NONE,
ciphers="^$:,;?*'dorothyx") ciphers="^$:,;?*'dorothyx")
with self.assertRaisesRegexp(ssl.SSLError, with self.assertRaisesRegexp(ssl.SSLError,
@ -138,14 +135,13 @@ class BasicSocketTests(unittest.TestCase):
finally: finally:
server.stop() server.stop()
server.join() server.join()
# repeat with AF_INET6?
@unittest.skipIf(platform.python_implementation() != "CPython", @unittest.skipIf(platform.python_implementation() != "CPython",
"Reference cycle test feasible under CPython only") "Reference cycle test feasible under CPython only")
def test_refcycle(self): def test_refcycle(self):
# Issue #7943: an SSL object doesn't create reference cycles with # Issue #7943: an SSL object doesn't create reference cycles with
# itself. # itself.
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
ss = ssl.wrap_socket(s) ss = ssl.wrap_socket(s)
wr = weakref.ref(ss) wr = weakref.ref(ss)
del ss del ss
@ -155,22 +151,23 @@ class BasicSocketTests(unittest.TestCase):
# The _delegate_methods in socket.py are correctly delegated to by an # The _delegate_methods in socket.py are correctly delegated to by an
# unconnected SSLSocket, so they will raise a socket.error rather than # unconnected SSLSocket, so they will raise a socket.error rather than
# something unexpected like TypeError. # something unexpected like TypeError.
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
ss = ssl.wrap_socket(s) ss = ssl.wrap_socket(s)
self.assertRaises(socket.error, ss.recv, 1) self.assertRaises(socket.error, ss.recv, 1)
self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
self.assertRaises(socket.error, ss.recvfrom, 1) self.assertRaises(socket.error, ss.recvfrom, 1)
self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
self.assertRaises(socket.error, ss.send, b'x') self.assertRaises(socket.error, ss.send, b'x')
self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) self.assertRaises(socket.error, ss.sendto, b'x',
('0.0.0.0', 0) if AF_INET4_6 == socket.AF_INET else
('::', 0))
class NetworkedTests(unittest.TestCase): class NetworkedTests(unittest.TestCase):
def test_connect(self): def test_connect(self):
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_NONE) cert_reqs=ssl.CERT_NONE)
s.connect(remote) s.connect(remote)
c = s.getpeercert() c = s.getpeercert()
@ -179,8 +176,7 @@ class NetworkedTests(unittest.TestCase):
s.close() s.close()
# this should fail because we have no verification certs # this should fail because we have no verification certs
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_REQUIRED) cert_reqs=ssl.CERT_REQUIRED)
try: try:
s.connect(remote) s.connect(remote)
@ -190,8 +186,7 @@ class NetworkedTests(unittest.TestCase):
s.close() s.close()
# this should succeed because we specify the root cert # this should succeed because we specify the root cert
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
ca_certs=ISSUER_CERTFILE) ca_certs=ISSUER_CERTFILE)
try: try:
@ -202,8 +197,7 @@ class NetworkedTests(unittest.TestCase):
def test_connect_ex(self): def test_connect_ex(self):
# Issue #11326: check connect_ex() implementation # Issue #11326: check connect_ex() implementation
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
ca_certs=ISSUER_CERTFILE) ca_certs=ISSUER_CERTFILE)
try: try:
@ -216,8 +210,7 @@ class NetworkedTests(unittest.TestCase):
# Issue #11326: non-blocking connect_ex() should allow handshake # Issue #11326: non-blocking connect_ex() should allow handshake
# to proceed after the socket gets ready. # to proceed after the socket gets ready.
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
ca_certs=ISSUER_CERTFILE, ca_certs=ISSUER_CERTFILE,
do_handshake_on_connect=False) do_handshake_on_connect=False)
@ -254,8 +247,7 @@ class NetworkedTests(unittest.TestCase):
# delay closing the underlying "real socket" (here tested with its # delay closing the underlying "real socket" (here tested with its
# file descriptor, hence skipping the test under Windows). # file descriptor, hence skipping the test under Windows).
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
ss = ssl.wrap_socket(socket.socket(socket.AF_INET, ss = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
socket.SOCK_DGRAM))
ss.connect(remote) ss.connect(remote)
fd = ss.fileno() fd = ss.fileno()
f = ss.makefile() f = ss.makefile()
@ -271,7 +263,7 @@ class NetworkedTests(unittest.TestCase):
def test_non_blocking_handshake(self): def test_non_blocking_handshake(self):
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
s.connect(remote) s.connect(remote)
s.setblocking(False) s.setblocking(False)
s = ssl.wrap_socket(s, s = ssl.wrap_socket(s,
@ -492,7 +484,7 @@ class ThreadedEchoServer(threading.Thread):
self.chatty = chatty self.chatty = chatty
self.connectionchatty = connectionchatty self.connectionchatty = connectionchatty
self.starttls_server = starttls_server self.starttls_server = starttls_server
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
self.flag = None self.flag = None
self.sock = ssl.wrap_socket(self.sock, server_side=True, self.sock = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.certificate, certfile=self.certificate,
@ -612,7 +604,7 @@ class AsyncoreEchoServer(threading.Thread):
def __init__(self, certfile, timeout_tracker): def __init__(self, certfile, timeout_tracker):
asyncore.dispatcher.__init__(self) asyncore.dispatcher.__init__(self)
self._timeout_tracker = timeout_tracker self._timeout_tracker = timeout_tracker
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
sock.setblocking(False) sock.setblocking(False)
sock.bind((HOST, 0)) sock.bind((HOST, 0))
self.sockname = sock.getsockname() self.sockname = sock.getsockname()
@ -631,7 +623,7 @@ class AsyncoreEchoServer(threading.Thread):
sock_obj, addr = acc_ret sock_obj, addr = acc_ret
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" server: new connection from " + sys.stdout.write(" server: new connection from " +
"%s:%s\n" %addr) "%s:%s\n" % (addr[0], str(addr[1:])))
self.ConnectionHandler(sock_obj, self._timeout_tracker) self.ConnectionHandler(sock_obj, self._timeout_tracker)
def handle_error(self): def handle_error(self):
@ -680,7 +672,7 @@ class SocketServerHTTPSServer(threading.Thread):
SocketServer.ThreadingTCPServer.__init__(self, server_address, SocketServer.ThreadingTCPServer.__init__(self, server_address,
RequestHandlerClass, False) RequestHandlerClass, False)
# account for dealing with a datagram socket # account for dealing with a datagram socket
self.socket = ssl.wrap_socket(socket.socket(socket.AF_INET, self.socket = ssl.wrap_socket(socket.socket(AF_INET4_6,
socket.SOCK_DGRAM), socket.SOCK_DGRAM),
server_side=True, server_side=True,
certfile=certfile, certfile=certfile,
@ -794,8 +786,7 @@ def bad_cert_test(certfile):
# try to connect # try to connect
try: try:
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
certfile=certfile, certfile=certfile,
ssl_version=ssl.PROTOCOL_DTLSv1) ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
@ -834,8 +825,7 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
if client_protocol is None: if client_protocol is None:
client_protocol = protocol client_protocol = protocol
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
certfile=client_certfile, certfile=client_certfile,
ca_certs=cacertsfile, ca_certs=cacertsfile,
ciphers=ciphers, ciphers=ciphers,
@ -909,11 +899,11 @@ def try_protocol_combo(server_protocol,
class ThreadedTests(unittest.TestCase): class ThreadedTests(unittest.TestCase):
def test_unreachable(self): def test_unreachable(self):
server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) server = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
server.bind((HOST, 0)) server.bind((HOST, 0))
port = server.getsockname()[1] port = server.getsockname()[1]
server.close() server.close()
s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
self.assertRaisesRegexp(ssl.SSLError, self.assertRaisesRegexp(ssl.SSLError,
"The peer address is not reachable", "The peer address is not reachable",
s.connect, (HOST, port)) s.connect, (HOST, port))
@ -940,8 +930,7 @@ class ThreadedTests(unittest.TestCase):
flag.wait() flag.wait()
# try to connect # try to connect
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
socket.SOCK_DGRAM),
certfile=CERTFILE, certfile=CERTFILE,
ca_certs=ISSUER_CERTFILE, ca_certs=ISSUER_CERTFILE,
cert_reqs=ssl.CERT_REQUIRED, cert_reqs=ssl.CERT_REQUIRED,
@ -1010,8 +999,7 @@ class ThreadedTests(unittest.TestCase):
# try to connect # try to connect
wrapped = False wrapped = False
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
socket.SOCK_DGRAM))
s.connect((HOST, server.port)) s.connect((HOST, server.port))
s = s.unwrap() s = s.unwrap()
if test_support.verbose: if test_support.verbose:
@ -1074,8 +1062,7 @@ class ThreadedTests(unittest.TestCase):
d1 = f.read() d1 = f.read()
d2 = [] d2 = []
# now fetch the same data from the HTTPS-UDP server # now fetch the same data from the HTTPS-UDP server
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
socket.SOCK_DGRAM))
s.connect((HOST, server.port)) s.connect((HOST, server.port))
fl = "/" + os.path.split(CERTFILE)[1] fl = "/" + os.path.split(CERTFILE)[1]
s.write("GET " + fl + " HTTP/1.1\r\n" + s.write("GET " + fl + " HTTP/1.1\r\n" +
@ -1123,8 +1110,7 @@ class ThreadedTests(unittest.TestCase):
flag.wait() flag.wait()
# try to connect # try to connect
try: try:
s = ssl.wrap_socket(socket.socket(socket.AF_INET, s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM))
socket.SOCK_DGRAM))
s.connect((HOST, server.port)) s.connect((HOST, server.port))
if test_support.verbose: if test_support.verbose:
sys.stdout.write( sys.stdout.write(
@ -1163,7 +1149,7 @@ class ThreadedTests(unittest.TestCase):
# wait for it to start # wait for it to start
flag.wait() flag.wait()
# try to connect # try to connect
s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM),
server_side=False, server_side=False,
certfile=CERTFILE, certfile=CERTFILE,
ca_certs=CERTFILE, ca_certs=CERTFILE,
@ -1264,13 +1250,13 @@ class ThreadedTests(unittest.TestCase):
def test_handshake_timeout(self): def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout # Issue #5103: SSL handshake must respect the socket timeout
server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) server = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
server.bind((HOST, 0)) server.bind((HOST, 0))
port = server.getsockname()[1] port = server.getsockname()[1]
try: try:
try: try:
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) c = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
c.settimeout(0.2) c.settimeout(0.2)
c.connect((HOST, port)) c.connect((HOST, port))
# Will attempt handshake and time out # Will attempt handshake and time out
@ -1279,7 +1265,7 @@ class ThreadedTests(unittest.TestCase):
finally: finally:
c.close() c.close()
try: try:
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) c = socket.socket(AF_INET4_6, socket.SOCK_DGRAM)
c.settimeout(0.2) c.settimeout(0.2)
c = ssl.wrap_socket(c) c = ssl.wrap_socket(c)
# Will attempt handshake and time out # Will attempt handshake and time out
@ -1292,7 +1278,7 @@ class ThreadedTests(unittest.TestCase):
def test_main(verbose=True): def test_main(verbose=True):
global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE, AF_INET4_6
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
"certs", "keycert.pem") "certs", "keycert.pem")
ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
@ -1306,6 +1292,12 @@ def test_main(verbose=True):
TestSupport.verbose = verbose TestSupport.verbose = verbose
do_patch() do_patch()
AF_INET4_6 = socket.AF_INET
res = unittest.main(exit=False).result.wasSuccessful()
if not res:
print "IPv4 test suite failed; not proceeding to IPv6"
sys.exit(not res)
AF_INET4_6 = socket.AF_INET6
unittest.main() unittest.main()
if __name__ == "__main__": if __name__ == "__main__":