From ff509e07248319da605a1ca2e269f9a75331ae02 Mon Sep 17 00:00:00 2001 From: mcfreis Date: Mon, 20 Mar 2017 16:39:50 +0100 Subject: [PATCH] Added more on error evaluation and a method to get the peer certificate chain * dtls/__init__.py: import error codes from err.py as error_codes for external access * dtls/err.py: Added errors for ERR_WRONG_SSL_VERSION, ERR_CERTIFICATE_VERIFY_FAILED, ERR_NO_SHARED_CIPHER and ERR_SSL_HANDSHAKE_FAILURE * dtls/openssl.py: - Added constant SSL_BUILD_CHAIN_FLAG_NONE for SSL_CTX_build_cert_chain() - Added method SSL_get_peer_cert_chain() * dtls/patch.py: Added getpeercertchain() as method to ssl.SSLSocket() * dtls/sslconnection.py: - Bugfix SSLContext.set_ecdh_curve() returns 1 for success and 0 for failure - SSLContext.build_cert_chain() changed default flags to SSL_BUILD_CHAIN_FLAG_NONE - In SSLConnection() the mtu size gets only set if no user config function is given - SSLConnection.listen() raises an exception for ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER and all other unknown errors - SSLConnection.read() and write() now can also raise ERR_PORT_UNREACHABLE - If SSLConnection.write() successfully writes bytes to the peer, then the handshake is assumed to be okay - Added method SSLConnection.getpeercertchain() * dtls/test/unit.py: ThreadedEchoServer() with an extra exception branch for the newly raised exceptions in SSLConnection.listen() --- ChangeLog | 20 ++++++ dtls/__init__.py | 17 ++--- dtls/err.py | 16 +++-- dtls/openssl.py | 67 ++++++++++++------- dtls/patch.py | 30 +++++---- dtls/sslconnection.py | 152 +++++++++++++++++++++++++----------------- dtls/test/unit.py | 26 ++++---- 7 files changed, 207 insertions(+), 121 deletions(-) diff --git a/ChangeLog b/ChangeLog index 549f76f..69577c4 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,23 @@ +2017-03-17 Björn Freise + + Added more on error evaluation and a method to get the peer certificate chain + + * dtls/__init__.py: import error codes from err.py as error_codes for external access + * dtls/err.py: Added errors for ERR_WRONG_SSL_VERSION, ERR_CERTIFICATE_VERIFY_FAILED, ERR_NO_SHARED_CIPHER and ERR_SSL_HANDSHAKE_FAILURE + * dtls/openssl.py: + - Added constant SSL_BUILD_CHAIN_FLAG_NONE for SSL_CTX_build_cert_chain() + - Added method SSL_get_peer_cert_chain() + * dtls/patch.py: Added getpeercertchain() as method to ssl.SSLSocket() + * dtls/sslconnection.py: + - Bugfix SSLContext.set_ecdh_curve() returns 1 for success and 0 for failure + - SSLContext.build_cert_chain() changed default flags to SSL_BUILD_CHAIN_FLAG_NONE + - In SSLConnection() the mtu size gets only set if no user config function is given + - SSLConnection.listen() raises an exception for ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER and all other unknown errors + - SSLConnection.read() and write() now can also raise ERR_PORT_UNREACHABLE + - If SSLConnection.write() successfully writes bytes to the peer, then the handshake is assumed to be okay + - Added method SSLConnection.getpeercertchain() + * dtls/test/unit.py: ThreadedEchoServer() with an extra exception branch for the newly raised exceptions in SSLConnection.listen() + 2017-03-17 Björn Freise Added certificate creation using ECDSA diff --git a/dtls/__init__.py b/dtls/__init__.py index 7912b1e..b5c5517 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -53,11 +53,12 @@ def _prep_bins(): for prebuilt_file in files: try: copy(path.join(prebuilt_path, prebuilt_file), package_root) - except IOError: - pass - -_prep_bins() # prepare before module imports - -from patch import do_patch -from sslconnection import SSLConnection -from demux import force_routing_demux, reset_default_demux + except IOError: + pass + +_prep_bins() # prepare before module imports + +from patch import do_patch +from sslconnection import SSLConnection +from demux import force_routing_demux, reset_default_demux +import err as error_codes diff --git a/dtls/err.py b/dtls/err.py index 4e8f271..b869288 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -42,20 +42,24 @@ SSL_ERROR_WANT_ACCEPT = 8 ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 -ERR_NO_CIPHER = 501 -ERR_READ_TIMEOUT = 502 +ERR_NO_CIPHER = 501 +ERR_READ_TIMEOUT = 502 ERR_WRITE_TIMEOUT = 503 ERR_HANDSHAKE_TIMEOUT = 504 ERR_PORT_UNREACHABLE = 505 +ERR_WRONG_SSL_VERSION = 0x1409210A ERR_WRONG_VERSION_NUMBER = 0x1408A10B ERR_COOKIE_MISMATCH = 0x1408A134 +ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086 +ERR_NO_SHARED_CIPHER = 0x1408A0C1 +ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5 -class SSLError(socket_error): - """This exception is raised by modules in the dtls package.""" - def __init__(self, *args): - super(SSLError, self).__init__(*args) +class SSLError(socket_error): + """This exception is raised by modules in the dtls package.""" + def __init__(self, *args): + super(SSLError, self).__init__(*args) class InvalidSocketError(Exception): diff --git a/dtls/openssl.py b/dtls/openssl.py index 8b3a9b8..d3e3140 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -92,12 +92,13 @@ SSL_VERIFY_CLIENT_ONCE = 0x04 SSL_SESS_CACHE_OFF = 0x0000 SSL_SESS_CACHE_CLIENT = 0x0001 SSL_SESS_CACHE_SERVER = 0x0002 -SSL_SESS_CACHE_BOTH = SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_SERVER -SSL_SESS_CACHE_NO_AUTO_CLEAR = 0x0080 -SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100 +SSL_SESS_CACHE_BOTH = SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_SERVER +SSL_SESS_CACHE_NO_AUTO_CLEAR = 0x0080 +SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100 SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200 SSL_SESS_CACHE_NO_INTERNAL = \ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE +SSL_BUILD_CHAIN_FLAG_NONE = 0x0 SSL_BUILD_CHAIN_FLAG_UNTRUSTED = 0x1 SSL_BUILD_CHAIN_FLAG_NO_ROOT = 0x2 SSL_BUILD_CHAIN_FLAG_CHECK = 0x4 @@ -345,18 +346,25 @@ class GENERAL_NAME(Structure): class GENERAL_NAMES(STACK): - stack_element_type = GENERAL_NAME - - def __init__(self, value): - super(GENERAL_NAMES, self).__init__(value) - - -class X509_NAME_ENTRY(Structure): - _fields_ = [("object", c_void_p), - ("value", c_void_p), - ("set", c_int), - ("size", c_int)] - + stack_element_type = GENERAL_NAME + + def __init__(self, value): + super(GENERAL_NAMES, self).__init__(value) + + +class STACK_OF_X509(STACK): + stack_element_type = X509 + + def __init__(self, value): + super(STACK_OF_X509, self).__init__(value) + + +class X509_NAME_ENTRY(Structure): + _fields_ = [("object", c_void_p), + ("value", c_void_p), + ("set", c_int), + ("size", c_int)] + class ASN1_OCTET_STRING(Structure): _fields_ = [("length", c_int), @@ -597,8 +605,8 @@ __all__ = [ "SSL_CB_ACCEPT_LOOP", "SSL_CB_ACCEPT_EXIT", "SSL_CB_CONNECT_LOOP", "SSL_CB_CONNECT_EXIT", "SSL_CB_HANDSHAKE_START", "SSL_CB_HANDSHAKE_DONE", - "SSL_BUILD_CHAIN_FLAG_UNTRUSTED", "SSL_BUILD_CHAIN_FLAG_NO_ROOT", "SSL_BUILD_CHAIN_FLAG_CHECK", - "SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR", "SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR", + "SSL_BUILD_CHAIN_FLAG_NONE", "SSL_BUILD_CHAIN_FLAG_UNTRUSTED", "SSL_BUILD_CHAIN_FLAG_NO_ROOT", + "SSL_BUILD_CHAIN_FLAG_CHECK", "SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR", "SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR", "SSL_FILE_TYPE_PEM", "GEN_DIRNAME", "NID_subject_alt_name", "CRYPTO_LOCK", @@ -628,6 +636,7 @@ __all__ = [ "SSL_set1_curves", "SSL_set1_curves_list", "SSL_set_mtu", "SSL_state_string_long", "SSL_alert_type_string_long", "SSL_alert_desc_string_long", + "SSL_get_peer_cert_chain", "SSL_CTX_set_cookie_cb", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", "OBJ_nid2sn", @@ -736,6 +745,8 @@ map(lambda x: _make_function(*x), ( ((c_int, "ret"), (SSL, "ssl"))), ("SSL_get_peer_certificate", libssl, ((X509, "ret"), (SSL, "ssl"))), + ("SSL_get_peer_cert_chain", libssl, + ((STACK_OF_X509, "ret"), (SSL, "ssl")), False), ("SSL_read", libssl, ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), ("SSL_write", libssl, @@ -1149,9 +1160,19 @@ def GENERAL_NAME_print(general_name): _free_func = addressof(c_void_p.in_dll(libcrypto, "sk_free")) def sk_pop_free(stack): - _sk_pop_free(stack, _free_func) - -def i2d_X509(x509): - bio = _BIO(BIO_new(BIO_s_mem())) - _i2d_X509_bio(bio.value, x509) - return BIO_get_mem_data(bio.value) + _sk_pop_free(stack, _free_func) + +def i2d_X509(x509): + bio = _BIO(BIO_new(BIO_s_mem())) + _i2d_X509_bio(bio.value, x509) + return BIO_get_mem_data(bio.value) + +def SSL_get_peer_cert_chain(ssl): + stack = _SSL_get_peer_cert_chain(ssl) + num = sk_num(stack) + certs = [] + if num: + # why not use sk_value(): because it doesn't cast correct in this case?! + # certs = [(sk_value(stack, i)) for i in xrange(num)] + certs = [X509(_sk_value(stack, i)) for i in xrange(num)] + return stack, num, certs diff --git a/dtls/patch.py b/dtls/patch.py index 207e072..2ca94a3 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -198,18 +198,24 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, self._user_config_ssl = cb_user_config_ssl # Perform method substitution and addition (without reference cycle) - self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) - self.listen = MethodType(_SSLSocket_listen, proxy(self)) - self.accept = MethodType(_SSLSocket_accept, proxy(self)) - self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) - self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) - -def _SSLSocket_listen(self, ignored): - if self._connected: - raise ValueError("attempt to listen on connected SSLSocket!") - if self._sslobj: - return - self._sslobj = SSLConnection(socket(_sock=self._sock), + self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) + self.listen = MethodType(_SSLSocket_listen, proxy(self)) + self.accept = MethodType(_SSLSocket_accept, proxy(self)) + self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) + self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) + + # Extra + self.getpeercertchain = MethodType(_getpeercertchain, proxy(self)) + +def _getpeercertchain(self, binary_form=False): + return self._sslobj.getpeercertchain(binary_form) + +def _SSLSocket_listen(self, ignored): + if self._connected: + raise ValueError("attempt to listen on connected SSLSocket!") + if self._sslobj: + return + self._sslobj = SSLConnection(socket(_sock=self._sock), self.keyfile, self.certfile, True, self.cert_reqs, self.ssl_version, self.ca_certs, diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 59853fb..7172f2c 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -52,15 +52,15 @@ from weakref import proxy from err import openssl_error, InvalidSocketError from err import raise_ssl_error from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL -from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT -from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR +from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS from x509 import _X509, decode_cert from tlock import tlock_init from openssl import * -from util import _Rsrc, _BIO - +from util import _Rsrc, _BIO + _logger = getLogger(__name__) PROTOCOL_DTLSv1 = 256 @@ -235,7 +235,8 @@ class SSLContext(object): return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves]) def set_ecdh_curve(self, curve_name=None): - u''' + u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode + Used for server only! s.a. openssl.exe ecparam -list_curves @@ -246,13 +247,13 @@ class SSLContext(object): if curve_name: retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0) avail_curves = get_elliptic_curves() - self._ctx.key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY() - retVal = SSL_CTX_set_tmp_ecdh(self._ctx, self._ctx.key) + key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY() + retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key) else: retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1) return retVal - def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT): + def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE): u''' Used for server side only! @@ -557,11 +558,11 @@ class SSLConnection(object): else: post_init = self._init_client(peer_address) - SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU) - DTLS_set_link_mtu(self._ssl.value, 1500) - if self._user_config_ssl: self._user_config_ssl(self._intf_ssl) + else: + SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU) + DTLS_set_link_mtu(self._ssl.value, 1500) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) self._rbio.disown() self._wbio.disown() @@ -632,22 +633,25 @@ class SSLConnection(object): self._ssl.raw) dtls_peer_address = DTLSv1_listen(self._ssl.value) except openssl_error() as err: - if err.ssl_error == SSL_ERROR_WANT_READ: + if err.ssl_error == SSL_ERROR_WANT_READ: # This method must be called again to forward the next datagram _logger.debug("DTLSv1_listen must be resumed") return elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER: _logger.debug("Wrong version number; aborting handshake") - return + raise elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: _logger.debug("Mismatching cookie received; aborting handshake") - return - _logger.exception("Unexpected error in DTLSv1_listen") - raise - finally: - self._listening = False - self._listening_peer_address = None - if type(peer_address) is tuple: + raise + elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER: + _logger.debug("No shared cipher; aborting handshake") + raise + _logger.exception("Unexpected error in DTLSv1_listen") + raise + finally: + self._listening = False + self._listening_peer_address = None + if type(peer_address) is tuple: _logger.debug("New local peer: %s", dtls_peer_address) self._pending_peer_address = peer_address else: @@ -730,35 +734,48 @@ class SSLConnection(object): Read up to len bytes and return them. Arguments: - len -- maximum number of bytes to read - - Return value: - string containing read bytes - """ - - return self._wrap_socket_library_call( - lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT) - - def write(self, data): - """Write data to connection - - Write data as string of bytes. - + len -- maximum number of bytes to read + + Return value: + string containing read bytes + """ + + try: + return self._wrap_socket_library_call( + lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT) + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + raise_ssl_error(ERR_PORT_UNREACHABLE, err) + raise + + def write(self, data): + """Write data to connection + + Write data as string of bytes. + Arguments: - data -- buffer containing data to be written - - Return value: - number of bytes actually transmitted - """ - - return self._wrap_socket_library_call( - lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) - - def shutdown(self): - """Shut down the DTLS connection - - This method attemps to complete a bidirectional shutdown between - peers. For non-blocking sockets, it should be called repeatedly until + data -- buffer containing data to be written + + Return value: + number of bytes actually transmitted + """ + + try: + ret = self._wrap_socket_library_call( + lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + raise_ssl_error(ERR_PORT_UNREACHABLE, err) + raise + if ret: + self._handshake_done = True + return ret + + def shutdown(self): + """Shut down the DTLS connection + + This method attemps to complete a bidirectional shutdown between + peers. For non-blocking sockets, it should be called repeatedly until it no longer raises continuation request exceptions. """ @@ -809,18 +826,33 @@ class SSLConnection(object): return if binary_form: - return i2d_X509(peer_cert.value) - if self._cert_reqs == CERT_NONE: - return {} - return decode_cert(peer_cert) - - peer_certificate = getpeercert # compatibility with _ssl call interface - - def cipher(self): - """Retrieve information about the current cipher - - Return a triple consisting of cipher name, SSL protocol version defining - its use, and the number of secret bits. Return None if handshaking + return i2d_X509(peer_cert.value) + if self._cert_reqs == CERT_NONE: + return {} + return decode_cert(peer_cert) + + peer_certificate = getpeercert # compatibility with _ssl call interface + + def getpeercertchain(self, binary_form=False): + try: + stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value) + except openssl_error(): + return + + peer_cert_chain = [_Rsrc(cert) for cert in certs] + ret = [] + if binary_form: + ret = [i2d_X509(x.value) for x in peer_cert_chain] + elif len(peer_cert_chain): + ret = [decode_cert(x) for x in peer_cert_chain] + + return ret + + def cipher(self): + """Retrieve information about the current cipher + + Return a triple consisting of cipher name, SSL protocol version defining + its use, and the number of secret bits. Return None if handshaking has not been completed. """ diff --git a/dtls/test/unit.py b/dtls/test/unit.py index c2a521d..eb7fa08 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -532,18 +532,20 @@ class ThreadedEchoServer(threading.Thread): if acc_ret: newconn, connaddr = acc_ret if test_support.verbose and self.chatty: - sys.stdout.write(' server: new connection from ' - + str(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn) - handler.start() - except socket.timeout: - pass - except KeyboardInterrupt: - self.stop() - self.sock.close() - - def register_handler(self, add): - with self.num_handlers_lock: + sys.stdout.write(' server: new connection from ' + + str(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn) + handler.start() + except socket.timeout: + pass + except ssl.SSLError: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() + + def register_handler(self, add): + with self.num_handlers_lock: if add: self.num_handlers += 1 else: