diff --git a/ChangeLog b/ChangeLog index cdc210e..83e7251 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,19 @@ +2017-03-17 Björn Freise + + Added new methods for DTLSv1.2 + + * dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER + * dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method() + * dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket() + * dtls/sslconnection.py: + - Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version) + - Updated _init_client() and _init_server() with the new protocol methods + - Default protocol DTLS for SSLConnection() + - Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version + * dtls/test/unit.py: + - Extended test_get_server_certificate() to iterate over the different protocol combinations + - Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server + 2017-03-17 Björn Freise Updating openSSL libs to v1.0.2l-dev diff --git a/dtls/err.py b/dtls/err.py index 212554b..4e8f271 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -44,12 +44,14 @@ ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 ERR_NO_CIPHER = 501 ERR_READ_TIMEOUT = 502 -ERR_WRITE_TIMEOUT = 503 -ERR_HANDSHAKE_TIMEOUT = 504 -ERR_PORT_UNREACHABLE = 505 -ERR_COOKIE_MISMATCH = 0x1408A134 - - +ERR_WRITE_TIMEOUT = 503 +ERR_HANDSHAKE_TIMEOUT = 504 +ERR_PORT_UNREACHABLE = 505 + +ERR_WRONG_VERSION_NUMBER = 0x1408A10B +ERR_COOKIE_MISMATCH = 0x1408A134 + + class SSLError(socket_error): """This exception is raised by modules in the dtls package.""" def __init__(self, *args): diff --git a/dtls/openssl.py b/dtls/openssl.py index b95138a..d2e872e 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -181,15 +181,15 @@ class FuncParam(object): @property def raw(self): - return self._as_parameter.value - - -class DTLSv1Method(FuncParam): - def __init__(self, value): - super(DTLSv1Method, self).__init__(value) - - -class BIO_METHOD(FuncParam): + return self._as_parameter.value + + +class DTLS_Method(FuncParam): + def __init__(self, value): + super(DTLS_Method, self).__init__(value) + + +class BIO_METHOD(FuncParam): def __init__(self, value): super(BIO_METHOD, self).__init__(value) @@ -563,12 +563,18 @@ map(lambda x: _make_function(*x), ( ((c_void_p, "ret"),), True, None), ("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)), + ("DTLS_server_method", libssl, + ((DTLS_Method, "ret"),)), ("DTLSv1_server_method", libssl, - ((DTLSv1Method, "ret"),)), + ((DTLS_Method, "ret"),)), + ("DTLSv1_2_server_method", libssl, + ((DTLS_Method, "ret"),)), ("DTLSv1_client_method", libssl, - ((DTLSv1Method, "ret"),)), + ((DTLS_Method, "ret"),)), + ("DTLSv1_2_client_method", libssl, + ((DTLS_Method, "ret"),)), ("SSL_CTX_new", libssl, - ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), + ((SSLCTX, "ret"), (DTLS_Method, "meth"))), ("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))), ("SSL_CTX_set_cookie_generate_cb", libssl, diff --git a/dtls/patch.py b/dtls/patch.py index f90132b..44105ae 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -36,11 +36,12 @@ has the following effects: from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM +from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE from types import MethodType from weakref import proxy import errno -from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE +from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO from err import raise_as_ssl_module_error @@ -49,32 +50,49 @@ def do_patch(): import ssl as _ssl # import to be avoided if ssl module is never patched global _orig_SSLSocket_init, _orig_get_server_certificate global ssl - ssl = _ssl - if hasattr(ssl, "PROTOCOL_DTLSv1"): - return - ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 - ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" - ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER - ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION - ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO + ssl = _ssl + if hasattr(ssl, "PROTOCOL_DTLSv1"): + return + _orig_wrap_socket = ssl.wrap_socket + ssl.wrap_socket = _wrap_socket + ssl.PROTOCOL_DTLS = PROTOCOL_DTLS + ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 + ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2 + ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS" + ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" + ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2" + ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER + ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION + ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO _orig_SSLSocket_init = ssl.SSLSocket.__init__ _orig_get_server_certificate = ssl.get_server_certificate ssl.SSLSocket.__init__ = _SSLSocket_init - ssl.get_server_certificate = _get_server_certificate - raise_as_ssl_module_error() - -PROTOCOL_SSLv23 = 2 - -def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): - """Retrieve a server certificate + ssl.get_server_certificate = _get_server_certificate + raise_as_ssl_module_error() + +def _wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_DTLS, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + + return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) + +def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): + """Retrieve a server certificate Retrieve the certificate from the server at the specified address, and return it as a PEM-encoded string. If 'ca_certs' is specified, validate the server cert against it. - If 'ssl_version' is specified, use it in the connection attempt. - """ - - if ssl_version != PROTOCOL_DTLSv1: + If 'ssl_version' is specified, use it in the connection attempt. + """ + + if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2): return _orig_get_server_certificate(addr, ssl_version, ca_certs) if ca_certs is not None: @@ -91,9 +109,9 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): return ssl.DER_cert_to_PEM_cert(dercert) def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, - do_handshake_on_connect=True, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_DTLS, ca_certs=None, + do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, server_hostname=None, diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index dde0b6b..2be2e0f 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -52,7 +52,7 @@ from weakref import proxy from err import openssl_error, InvalidSocketError from err import raise_ssl_error from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL -from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR @@ -61,12 +61,14 @@ from tlock import tlock_init from openssl import * from util import _Rsrc, _BIO -_logger = getLogger(__name__) - -PROTOCOL_DTLSv1 = 256 -CERT_NONE = 0 -CERT_OPTIONAL = 1 -CERT_REQUIRED = 2 +_logger = getLogger(__name__) + +PROTOCOL_DTLSv1 = 256 +PROTOCOL_DTLSv1_2 = 258 +PROTOCOL_DTLS = 259 +CERT_NONE = 0 +CERT_OPTIONAL = 1 +CERT_REQUIRED = 2 # # One-time global OpenSSL library initialization @@ -199,13 +201,18 @@ class SSLConnection(object): rsock = self._udp_demux.get_connection(None) if rsock is self._sock: self._rbio = self._wbio - else: - self._rsock = rsock - self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) - self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) - SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) - if self._cert_reqs == CERT_NONE: - verify_mode = SSL_VERIFY_NONE + else: + self._rsock = rsock + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + server_method = DTLS_server_method + if self._ssl_version == PROTOCOL_DTLSv1_2: + server_method = DTLSv1_2_server_method + elif self._ssl_version == PROTOCOL_DTLSv1: + server_method = DTLSv1_server_method + self._ctx = _CTX(SSL_CTX_new(server_method())) + SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) + if self._cert_reqs == CERT_NONE: + verify_mode = SSL_VERIFY_NONE elif self._cert_reqs == CERT_OPTIONAL: verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE else: @@ -229,13 +236,18 @@ class SSLConnection(object): def _init_client(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") - - self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) - self._rbio = self._wbio - self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) - if self._cert_reqs == CERT_NONE: - verify_mode = SSL_VERIFY_NONE - else: + + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = self._wbio + client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible) + if self._ssl_version == PROTOCOL_DTLSv1_2: + client_method = DTLSv1_2_client_method + elif self._ssl_version == PROTOCOL_DTLSv1: + client_method = DTLSv1_client_method + self._ctx = _CTX(SSL_CTX_new(client_method())) + if self._cert_reqs == CERT_NONE: + verify_mode = SSL_VERIFY_NONE + else: verify_mode = SSL_VERIFY_PEER self._config_ssl_ctx(verify_mode) self._ssl = _SSL(SSL_new(self._ctx.value)) @@ -361,13 +373,13 @@ class SSLConnection(object): def _verify_cookie_cb(self, ssl, cookie): if self._get_cookie(ssl) != cookie: raise Exception("DTLS cookie mismatch") - - def __init__(self, sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_DTLSv1, ca_certs=None, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): - """Constructor + + def __init__(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_DTLS, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + """Constructor Arguments: these arguments match the ones of the SSLSocket class in the @@ -483,12 +495,15 @@ class SSLConnection(object): dtls_peer_address = DTLSv1_listen(self._ssl.value) except openssl_error() as err: if err.ssl_error == SSL_ERROR_WANT_READ: - # This method must be called again to forward the next datagram - _logger.debug("DTLSv1_listen must be resumed") - return - elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: - _logger.debug("Mismatching cookie received; aborting handshake") - return + # This method must be called again to forward the next datagram + _logger.debug("DTLSv1_listen must be resumed") + return + elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER: + _logger.debug("Wrong version number; aborting handshake") + return + elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: + _logger.debug("Mismatching cookie received; aborting handshake") + return _logger.exception("Unexpected error in DTLSv1_listen") raise finally: diff --git a/dtls/test/unit.py b/dtls/test/unit.py index b7e28a5..c2a521d 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -78,10 +78,12 @@ class BasicSocketTests(unittest.TestCase): def test_constants(self): ssl.PROTOCOL_SSLv23 - ssl.PROTOCOL_TLSv1 - ssl.PROTOCOL_DTLSv1 # added - ssl.CERT_NONE - ssl.CERT_OPTIONAL + ssl.PROTOCOL_TLSv1 + ssl.PROTOCOL_DTLSv1 # added + ssl.PROTOCOL_DTLSv1_2 # added + ssl.PROTOCOL_DTLS # added + ssl.CERT_NONE + ssl.CERT_OPTIONAL ssl.CERT_REQUIRED def test_dtls_openssl_version(self): @@ -89,22 +91,22 @@ class BasicSocketTests(unittest.TestCase): t = ssl.DTLS_OPENSSL_VERSION_INFO s = ssl.DTLS_OPENSSL_VERSION self.assertIsInstance(n, (int, long)) - self.assertIsInstance(t, tuple) - self.assertIsInstance(s, str) - # Some sanity checks follow - # >= 1.0 - self.assertGreaterEqual(n, 0x10000000) - # < 2.0 - self.assertLess(n, 0x20000000) - major, minor, fix, patch, status = t + self.assertIsInstance(t, tuple) + self.assertIsInstance(s, str) + # Some sanity checks follow + # >= 1.0.2 + self.assertGreaterEqual(n, 0x10002000) + # < 2.0 + self.assertLess(n, 0x20000000) + major, minor, fix, patch, status = t self.assertGreaterEqual(major, 1) - self.assertLess(major, 2) - self.assertGreaterEqual(minor, 0) - self.assertLess(minor, 256) - self.assertGreaterEqual(fix, 0) - self.assertLess(fix, 256) - self.assertGreaterEqual(patch, 0) - self.assertLessEqual(patch, 26) + self.assertLess(major, 2) + self.assertGreaterEqual(minor, 0) + self.assertLess(minor, 256) + self.assertGreaterEqual(fix, 2) + self.assertLess(fix, 256) + self.assertGreaterEqual(patch, 0) + self.assertLessEqual(patch, 26) self.assertGreaterEqual(status, 0) self.assertLessEqual(status, 15) # Version string as returned by OpenSSL, the format might change @@ -297,35 +299,37 @@ class NetworkedTests(unittest.TestCase): s.close() if test_support.verbose: sys.stdout.write(("\nNeeded %d calls to do_handshake() " + - "to establish session.\n") % count) - - def test_get_server_certificate(self): - with test_support.transient_internet() as remote: - pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) - if not pem: - self.fail("No server certificate!") - - try: - pem = ssl.get_server_certificate(remote, - ssl.PROTOCOL_DTLSv1, - ca_certs=OTHER_CERTFILE) - except ssl.SSLError: - #should fail - pass - else: - self.fail("Got server certificate %s!" % pem) - - pem = ssl.get_server_certificate(remote, - ssl.PROTOCOL_DTLSv1, - ca_certs=ISSUER_CERTFILE) - if not pem: - self.fail("No server certificate!") - if test_support.verbose: - sys.stdout.write("\nVerified certificate is\n%s\n" % pem) - -class ThreadedEchoServer(threading.Thread): - - class ConnectionHandler(threading.Thread): + "to establish session.\n") % count) + + def test_get_server_certificate(self): + for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS): + with test_support.transient_internet() as remote: + pem = ssl.get_server_certificate(remote, + prot) + if not pem: + self.fail("No server certificate!") + + try: + pem = ssl.get_server_certificate(remote, + prot, + ca_certs=OTHER_CERTFILE) + except ssl.SSLError: + # should fail + pass + else: + self.fail("Got server certificate %s!" % pem) + + pem = ssl.get_server_certificate(remote, + prot, + ca_certs=ISSUER_CERTFILE) + if not pem: + self.fail("No server certificate!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate is\n%s\n" % pem) + +class ThreadedEchoServer(threading.Thread): + + class ConnectionHandler(threading.Thread): """A mildly complicated class, because we want it to work both with and without the SSL wrapper around the socket connection, so @@ -1036,17 +1040,40 @@ class ThreadedTests(unittest.TestCase): "certs", "badkey.pem")) def test_protocol_dtlsv1(self): - """Connecting to a DTLSv1 server with various client options""" - if test_support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) - try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, - ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, - ssl.CERT_REQUIRED) - - def test_starttls(self): - """Switching from clear text to encrypted and back again.""" + """Connecting to a DTLSv1 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + # server: 1.0 - client: 1.0 -> ok + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_REQUIRED) + # server: any - client: 1.0 and 1.2(any) -> ok + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True) + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True, + ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True) + try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True, + ssl.CERT_REQUIRED) + # server: 1.0 - client: 1.2 -> fail + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False, + ssl.CERT_REQUIRED) + # server: 1.2 - client: 1.0 -> fail + try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False) + try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False, + ssl.CERT_REQUIRED) + # server: 1.2 - client: 1.2 -> ok + try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True, + ssl.CERT_REQUIRED) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6") @@ -1059,13 +1086,13 @@ class ThreadedTests(unittest.TestCase): server.start(flag) # wait for it to start flag.wait() - # try to connect - wrapped = False - try: - s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM)) - s.connect((HOST, server.port)) - s = s.unwrap() - if test_support.verbose: + # try to connect + wrapped = False + try: + s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + s = s.unwrap() + if test_support.verbose: sys.stdout.write("\n") for indata in msgs: if test_support.verbose: