Added new methods for DTLSv1.2

* dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER
* dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method()
* dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket()
* dtls/sslconnection.py:
	- Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version)
	- Updated _init_client() and _init_server() with the new protocol methods
	- Default protocol DTLS for SSLConnection()
	- Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version
* dtls/test/unit.py:
	- Extended test_get_server_certificate() to iterate over the different protocol combinations
	- Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server
incoming
mcfreis 2017-03-20 14:48:42 +01:00
parent 8b07f3f46d
commit 83204e8c4d
6 changed files with 225 additions and 141 deletions

View File

@ -1,3 +1,19 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added new methods for DTLSv1.2
* dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER
* dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method()
* dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket()
* dtls/sslconnection.py:
- Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version)
- Updated _init_client() and _init_server() with the new protocol methods
- Default protocol DTLS for SSLConnection()
- Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version
* dtls/test/unit.py:
- Extended test_get_server_certificate() to iterate over the different protocol combinations
- Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Updating openSSL libs to v1.0.2l-dev Updating openSSL libs to v1.0.2l-dev

View File

@ -44,12 +44,14 @@ ERR_BOTH_KEY_CERT_FILES_SVR = 298
ERR_NO_CERTS = 331 ERR_NO_CERTS = 331
ERR_NO_CIPHER = 501 ERR_NO_CIPHER = 501
ERR_READ_TIMEOUT = 502 ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503 ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504 ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505 ERR_PORT_UNREACHABLE = 505
ERR_COOKIE_MISMATCH = 0x1408A134
ERR_WRONG_VERSION_NUMBER = 0x1408A10B
ERR_COOKIE_MISMATCH = 0x1408A134
class SSLError(socket_error): class SSLError(socket_error):
"""This exception is raised by modules in the dtls package.""" """This exception is raised by modules in the dtls package."""
def __init__(self, *args): def __init__(self, *args):

View File

@ -181,15 +181,15 @@ class FuncParam(object):
@property @property
def raw(self): def raw(self):
return self._as_parameter.value return self._as_parameter.value
class DTLSv1Method(FuncParam): class DTLS_Method(FuncParam):
def __init__(self, value): def __init__(self, value):
super(DTLSv1Method, self).__init__(value) super(DTLS_Method, self).__init__(value)
class BIO_METHOD(FuncParam): class BIO_METHOD(FuncParam):
def __init__(self, value): def __init__(self, value):
super(BIO_METHOD, self).__init__(value) super(BIO_METHOD, self).__init__(value)
@ -563,12 +563,18 @@ map(lambda x: _make_function(*x), (
((c_void_p, "ret"),), True, None), ((c_void_p, "ret"),), True, None),
("CRYPTO_num_locks", libcrypto, ("CRYPTO_num_locks", libcrypto,
((c_int, "ret"),)), ((c_int, "ret"),)),
("DTLS_server_method", libssl,
((DTLS_Method, "ret"),)),
("DTLSv1_server_method", libssl, ("DTLSv1_server_method", libssl,
((DTLSv1Method, "ret"),)), ((DTLS_Method, "ret"),)),
("DTLSv1_2_server_method", libssl,
((DTLS_Method, "ret"),)),
("DTLSv1_client_method", libssl, ("DTLSv1_client_method", libssl,
((DTLSv1Method, "ret"),)), ((DTLS_Method, "ret"),)),
("DTLSv1_2_client_method", libssl,
((DTLS_Method, "ret"),)),
("SSL_CTX_new", libssl, ("SSL_CTX_new", libssl,
((SSLCTX, "ret"), (DTLSv1Method, "meth"))), ((SSLCTX, "ret"), (DTLS_Method, "meth"))),
("SSL_CTX_free", libssl, ("SSL_CTX_free", libssl,
((None, "ret"), (SSLCTX, "ctx"))), ((None, "ret"), (SSLCTX, "ctx"))),
("SSL_CTX_set_cookie_generate_cb", libssl, ("SSL_CTX_set_cookie_generate_cb", libssl,

View File

@ -36,11 +36,12 @@ has the following effects:
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from types import MethodType from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno
from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error from err import raise_as_ssl_module_error
@ -49,32 +50,49 @@ def do_patch():
import ssl as _ssl # import to be avoided if ssl module is never patched import ssl as _ssl # import to be avoided if ssl module is never patched
global _orig_SSLSocket_init, _orig_get_server_certificate global _orig_SSLSocket_init, _orig_get_server_certificate
global ssl global ssl
ssl = _ssl ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"): if hasattr(ssl, "PROTOCOL_DTLSv1"):
return return
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 _orig_wrap_socket = ssl.wrap_socket
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl.wrap_socket = _wrap_socket
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.PROTOCOL_DTLS = PROTOCOL_DTLS
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
_orig_SSLSocket_init = ssl.SSLSocket.__init__ _orig_SSLSocket_init = ssl.SSLSocket.__init__
_orig_get_server_certificate = ssl.get_server_certificate _orig_get_server_certificate = ssl.get_server_certificate
ssl.SSLSocket.__init__ = _SSLSocket_init ssl.SSLSocket.__init__ = _SSLSocket_init
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error() raise_as_ssl_module_error()
PROTOCOL_SSLv23 = 2 def _wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): ssl_version=PROTOCOL_DTLS, ca_certs=None,
"""Retrieve a server certificate do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve a server certificate
Retrieve the certificate from the server at the specified address, Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string. and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it. If 'ca_certs' is specified, validate the server cert against it.
If 'ssl_version' is specified, use it in the connection attempt. If 'ssl_version' is specified, use it in the connection attempt.
""" """
if ssl_version != PROTOCOL_DTLSv1: if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2):
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if ca_certs is not None: if ca_certs is not None:
@ -91,9 +109,9 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
return ssl.DER_cert_to_PEM_cert(dercert) return ssl.DER_cert_to_PEM_cert(dercert)
def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None, server_hostname=None,

View File

@ -52,7 +52,7 @@ from weakref import proxy
from err import openssl_error, InvalidSocketError from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR
@ -61,12 +61,14 @@ from tlock import tlock_init
from openssl import * from openssl import *
from util import _Rsrc, _BIO from util import _Rsrc, _BIO
_logger = getLogger(__name__) _logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256 PROTOCOL_DTLSv1 = 256
CERT_NONE = 0 PROTOCOL_DTLSv1_2 = 258
CERT_OPTIONAL = 1 PROTOCOL_DTLS = 259
CERT_REQUIRED = 2 CERT_NONE = 0
CERT_OPTIONAL = 1
CERT_REQUIRED = 2
# #
# One-time global OpenSSL library initialization # One-time global OpenSSL library initialization
@ -199,13 +201,18 @@ class SSLConnection(object):
rsock = self._udp_demux.get_connection(None) rsock = self._udp_demux.get_connection(None)
if rsock is self._sock: if rsock is self._sock:
self._rbio = self._wbio self._rbio = self._wbio
else: else:
self._rsock = rsock self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) server_method = DTLS_server_method
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._ssl_version == PROTOCOL_DTLSv1_2:
if self._cert_reqs == CERT_NONE: server_method = DTLSv1_2_server_method
verify_mode = SSL_VERIFY_NONE elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
elif self._cert_reqs == CERT_OPTIONAL: elif self._cert_reqs == CERT_OPTIONAL:
verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE
else: else:
@ -229,13 +236,18 @@ class SSLConnection(object):
def _init_client(self, peer_address): def _init_client(self, peer_address):
if self._sock.type != socket.SOCK_DGRAM: if self._sock.type != socket.SOCK_DGRAM:
raise InvalidSocketError("sock must be of type SOCK_DGRAM") raise InvalidSocketError("sock must be of type SOCK_DGRAM")
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._cert_reqs == CERT_NONE: if self._ssl_version == PROTOCOL_DTLSv1_2:
verify_mode = SSL_VERIFY_NONE client_method = DTLSv1_2_client_method
else: elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
else:
verify_mode = SSL_VERIFY_PEER verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode) self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
@ -361,13 +373,13 @@ class SSLConnection(object):
def _verify_cookie_cb(self, ssl, cookie): def _verify_cookie_cb(self, ssl, cookie):
if self._get_cookie(ssl) != cookie: if self._get_cookie(ssl) != cookie:
raise Exception("DTLS cookie mismatch") raise Exception("DTLS cookie mismatch")
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLSv1, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True, ciphers=None):
"""Constructor """Constructor
Arguments: Arguments:
these arguments match the ones of the SSLSocket class in the these arguments match the ones of the SSLSocket class in the
@ -483,12 +495,15 @@ class SSLConnection(object):
dtls_peer_address = DTLSv1_listen(self._ssl.value) dtls_peer_address = DTLSv1_listen(self._ssl.value)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ: if err.ssl_error == SSL_ERROR_WANT_READ:
# This method must be called again to forward the next datagram # This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed") _logger.debug("DTLSv1_listen must be resumed")
return return
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Mismatching cookie received; aborting handshake") _logger.debug("Wrong version number; aborting handshake")
return return
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
_logger.debug("Mismatching cookie received; aborting handshake")
return
_logger.exception("Unexpected error in DTLSv1_listen") _logger.exception("Unexpected error in DTLSv1_listen")
raise raise
finally: finally:

View File

@ -78,10 +78,12 @@ class BasicSocketTests(unittest.TestCase):
def test_constants(self): def test_constants(self):
ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_TLSv1 ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_DTLSv1 # added ssl.PROTOCOL_DTLSv1 # added
ssl.CERT_NONE ssl.PROTOCOL_DTLSv1_2 # added
ssl.CERT_OPTIONAL ssl.PROTOCOL_DTLS # added
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED ssl.CERT_REQUIRED
def test_dtls_openssl_version(self): def test_dtls_openssl_version(self):
@ -89,22 +91,22 @@ class BasicSocketTests(unittest.TestCase):
t = ssl.DTLS_OPENSSL_VERSION_INFO t = ssl.DTLS_OPENSSL_VERSION_INFO
s = ssl.DTLS_OPENSSL_VERSION s = ssl.DTLS_OPENSSL_VERSION
self.assertIsInstance(n, (int, long)) self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple) self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str) self.assertIsInstance(s, str)
# Some sanity checks follow # Some sanity checks follow
# >= 1.0 # >= 1.0.2
self.assertGreaterEqual(n, 0x10000000) self.assertGreaterEqual(n, 0x10002000)
# < 2.0 # < 2.0
self.assertLess(n, 0x20000000) self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 1) self.assertGreaterEqual(major, 1)
self.assertLess(major, 2) self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0) self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256) self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 0) self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256) self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0) self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26) self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0) self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15) self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change # Version string as returned by OpenSSL, the format might change
@ -297,35 +299,37 @@ class NetworkedTests(unittest.TestCase):
s.close() s.close()
if test_support.verbose: if test_support.verbose:
sys.stdout.write(("\nNeeded %d calls to do_handshake() " + sys.stdout.write(("\nNeeded %d calls to do_handshake() " +
"to establish session.\n") % count) "to establish session.\n") % count)
def test_get_server_certificate(self): def test_get_server_certificate(self):
with test_support.transient_internet() as remote: for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) with test_support.transient_internet() as remote:
if not pem: pem = ssl.get_server_certificate(remote,
self.fail("No server certificate!") prot)
if not pem:
try: self.fail("No server certificate!")
pem = ssl.get_server_certificate(remote,
ssl.PROTOCOL_DTLSv1, try:
ca_certs=OTHER_CERTFILE) pem = ssl.get_server_certificate(remote,
except ssl.SSLError: prot,
#should fail ca_certs=OTHER_CERTFILE)
pass except ssl.SSLError:
else: # should fail
self.fail("Got server certificate %s!" % pem) pass
else:
pem = ssl.get_server_certificate(remote, self.fail("Got server certificate %s!" % pem)
ssl.PROTOCOL_DTLSv1,
ca_certs=ISSUER_CERTFILE) pem = ssl.get_server_certificate(remote,
if not pem: prot,
self.fail("No server certificate!") ca_certs=ISSUER_CERTFILE)
if test_support.verbose: if not pem:
sys.stdout.write("\nVerified certificate is\n%s\n" % pem) self.fail("No server certificate!")
if test_support.verbose:
class ThreadedEchoServer(threading.Thread): sys.stdout.write("\nVerified certificate is\n%s\n" % pem)
class ConnectionHandler(threading.Thread): class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread):
"""A mildly complicated class, because we want it to work both """A mildly complicated class, because we want it to work both
with and without the SSL wrapper around the socket connection, so with and without the SSL wrapper around the socket connection, so
@ -1036,17 +1040,40 @@ class ThreadedTests(unittest.TestCase):
"certs", "badkey.pem")) "certs", "badkey.pem"))
def test_protocol_dtlsv1(self): def test_protocol_dtlsv1(self):
"""Connecting to a DTLSv1 server with various client options""" """Connecting to a DTLSv1 server with various client options"""
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) # server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, ssl.CERT_OPTIONAL)
ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
def test_starttls(self): # server: any - client: 1.0 and 1.2(any) -> ok
"""Switching from clear text to encrypted and back again.""" try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS",
"msg 5", "msg 6") "msg 5", "msg 6")
@ -1059,13 +1086,13 @@ class ThreadedTests(unittest.TestCase):
server.start(flag) server.start(flag)
# wait for it to start # wait for it to start
flag.wait() flag.wait()
# try to connect # try to connect
wrapped = False wrapped = False
try: try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM)) s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
s = s.unwrap() s = s.unwrap()
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
for indata in msgs: for indata in msgs:
if test_support.verbose: if test_support.verbose: