Added new methods for DTLSv1.2

* dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER
* dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method()
* dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket()
* dtls/sslconnection.py:
	- Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version)
	- Updated _init_client() and _init_server() with the new protocol methods
	- Default protocol DTLS for SSLConnection()
	- Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version
* dtls/test/unit.py:
	- Extended test_get_server_certificate() to iterate over the different protocol combinations
	- Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server
incoming
mcfreis 2017-03-20 14:48:42 +01:00
parent 8b07f3f46d
commit 83204e8c4d
6 changed files with 225 additions and 141 deletions

View File

@ -1,3 +1,19 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added new methods for DTLSv1.2
* dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER
* dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method()
* dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket()
* dtls/sslconnection.py:
- Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version)
- Updated _init_client() and _init_server() with the new protocol methods
- Default protocol DTLS for SSLConnection()
- Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version
* dtls/test/unit.py:
- Extended test_get_server_certificate() to iterate over the different protocol combinations
- Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Updating openSSL libs to v1.0.2l-dev Updating openSSL libs to v1.0.2l-dev

View File

@ -47,6 +47,8 @@ ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503 ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504 ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505 ERR_PORT_UNREACHABLE = 505
ERR_WRONG_VERSION_NUMBER = 0x1408A10B
ERR_COOKIE_MISMATCH = 0x1408A134 ERR_COOKIE_MISMATCH = 0x1408A134

View File

@ -184,9 +184,9 @@ class FuncParam(object):
return self._as_parameter.value return self._as_parameter.value
class DTLSv1Method(FuncParam): class DTLS_Method(FuncParam):
def __init__(self, value): def __init__(self, value):
super(DTLSv1Method, self).__init__(value) super(DTLS_Method, self).__init__(value)
class BIO_METHOD(FuncParam): class BIO_METHOD(FuncParam):
@ -563,12 +563,18 @@ map(lambda x: _make_function(*x), (
((c_void_p, "ret"),), True, None), ((c_void_p, "ret"),), True, None),
("CRYPTO_num_locks", libcrypto, ("CRYPTO_num_locks", libcrypto,
((c_int, "ret"),)), ((c_int, "ret"),)),
("DTLS_server_method", libssl,
((DTLS_Method, "ret"),)),
("DTLSv1_server_method", libssl, ("DTLSv1_server_method", libssl,
((DTLSv1Method, "ret"),)), ((DTLS_Method, "ret"),)),
("DTLSv1_2_server_method", libssl,
((DTLS_Method, "ret"),)),
("DTLSv1_client_method", libssl, ("DTLSv1_client_method", libssl,
((DTLSv1Method, "ret"),)), ((DTLS_Method, "ret"),)),
("DTLSv1_2_client_method", libssl,
((DTLS_Method, "ret"),)),
("SSL_CTX_new", libssl, ("SSL_CTX_new", libssl,
((SSLCTX, "ret"), (DTLSv1Method, "meth"))), ((SSLCTX, "ret"), (DTLS_Method, "meth"))),
("SSL_CTX_free", libssl, ("SSL_CTX_free", libssl,
((None, "ret"), (SSLCTX, "ctx"))), ((None, "ret"), (SSLCTX, "ctx"))),
("SSL_CTX_set_cookie_generate_cb", libssl, ("SSL_CTX_set_cookie_generate_cb", libssl,

View File

@ -36,11 +36,12 @@ has the following effects:
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from types import MethodType from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno
from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error from err import raise_as_ssl_module_error
@ -52,8 +53,14 @@ def do_patch():
ssl = _ssl ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"): if hasattr(ssl, "PROTOCOL_DTLSv1"):
return return
_orig_wrap_socket = ssl.wrap_socket
ssl.wrap_socket = _wrap_socket
ssl.PROTOCOL_DTLS = PROTOCOL_DTLS
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
@ -63,7 +70,18 @@ def do_patch():
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error() raise_as_ssl_module_error()
PROTOCOL_SSLv23 = 2 def _wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve a server certificate """Retrieve a server certificate
@ -74,7 +92,7 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
If 'ssl_version' is specified, use it in the connection attempt. If 'ssl_version' is specified, use it in the connection attempt.
""" """
if ssl_version != PROTOCOL_DTLSv1: if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2):
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if ca_certs is not None: if ca_certs is not None:
@ -92,7 +110,7 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,

View File

@ -52,7 +52,7 @@ from weakref import proxy
from err import openssl_error, InvalidSocketError from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR
@ -64,6 +64,8 @@ from util import _Rsrc, _BIO
_logger = getLogger(__name__) _logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256 PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259
CERT_NONE = 0 CERT_NONE = 0
CERT_OPTIONAL = 1 CERT_OPTIONAL = 1
CERT_REQUIRED = 2 CERT_REQUIRED = 2
@ -202,7 +204,12 @@ class SSLConnection(object):
else: else:
self._rsock = rsock self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) server_method = DTLS_server_method
if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
@ -232,7 +239,12 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
else: else:
@ -364,7 +376,7 @@ class SSLConnection(object):
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLSv1, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True, ciphers=None):
"""Constructor """Constructor
@ -486,6 +498,9 @@ class SSLConnection(object):
# This method must be called again to forward the next datagram # This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed") _logger.debug("DTLSv1_listen must be resumed")
return return
elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Wrong version number; aborting handshake")
return
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
_logger.debug("Mismatching cookie received; aborting handshake") _logger.debug("Mismatching cookie received; aborting handshake")
return return

View File

@ -80,6 +80,8 @@ class BasicSocketTests(unittest.TestCase):
ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_TLSv1 ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_DTLSv1 # added ssl.PROTOCOL_DTLSv1 # added
ssl.PROTOCOL_DTLSv1_2 # added
ssl.PROTOCOL_DTLS # added
ssl.CERT_NONE ssl.CERT_NONE
ssl.CERT_OPTIONAL ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED ssl.CERT_REQUIRED
@ -92,8 +94,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertIsInstance(t, tuple) self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str) self.assertIsInstance(s, str)
# Some sanity checks follow # Some sanity checks follow
# >= 1.0 # >= 1.0.2
self.assertGreaterEqual(n, 0x10000000) self.assertGreaterEqual(n, 0x10002000)
# < 2.0 # < 2.0
self.assertLess(n, 0x20000000) self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t major, minor, fix, patch, status = t
@ -101,7 +103,7 @@ class BasicSocketTests(unittest.TestCase):
self.assertLess(major, 2) self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0) self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256) self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 0) self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256) self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0) self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26) self.assertLessEqual(patch, 26)
@ -300,23 +302,25 @@ class NetworkedTests(unittest.TestCase):
"to establish session.\n") % count) "to establish session.\n") % count)
def test_get_server_certificate(self): def test_get_server_certificate(self):
for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) pem = ssl.get_server_certificate(remote,
prot)
if not pem: if not pem:
self.fail("No server certificate!") self.fail("No server certificate!")
try: try:
pem = ssl.get_server_certificate(remote, pem = ssl.get_server_certificate(remote,
ssl.PROTOCOL_DTLSv1, prot,
ca_certs=OTHER_CERTFILE) ca_certs=OTHER_CERTFILE)
except ssl.SSLError: except ssl.SSLError:
#should fail # should fail
pass pass
else: else:
self.fail("Got server certificate %s!" % pem) self.fail("Got server certificate %s!" % pem)
pem = ssl.get_server_certificate(remote, pem = ssl.get_server_certificate(remote,
ssl.PROTOCOL_DTLSv1, prot,
ca_certs=ISSUER_CERTFILE) ca_certs=ISSUER_CERTFILE)
if not pem: if not pem:
self.fail("No server certificate!") self.fail("No server certificate!")
@ -1039,11 +1043,34 @@ class ThreadedTests(unittest.TestCase):
"""Connecting to a DTLSv1 server with various client options""" """Connecting to a DTLSv1 server with various client options"""
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
# server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_OPTIONAL) ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
# server: any - client: 1.0 and 1.2(any) -> ok
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
def test_starttls(self): def test_starttls(self):
"""Switching from clear text to encrypted and back again.""" """Switching from clear text to encrypted and back again."""
@ -1062,7 +1089,7 @@ class ThreadedTests(unittest.TestCase):
# try to connect # try to connect
wrapped = False wrapped = False
try: try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM)) s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
s = s.unwrap() s = s.unwrap()
if test_support.verbose: if test_support.verbose: