diff --git a/ChangeLog b/ChangeLog index b08cbf0..f2e7b27 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,17 @@ +2017-03-23 Björn Freise + + Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client() + + * dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client() + * dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants + * dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors() + * dtls/wrapper.py: + - Added a server and client function to alias/wrap DtlsSocket() creation + - Cleanup of DtlsSocket.__init__() + - Cleanup of exception handling in all member methods + - Cleanup sendto() from client: no endless loop and first do a connect if not already connected + * dtls/test/unit_wrapper.py: Adopt the changes made described above + 2017-03-17 Björn Freise Added a wrapper for a DTLS-Socket either as client or server - including unit tests diff --git a/dtls/__init__.py b/dtls/__init__.py index e263633..86d023c 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -53,12 +53,12 @@ def _prep_bins(): for prebuilt_file in files: try: copy(path.join(prebuilt_path, prebuilt_file), package_root) - except IOError: - pass - -_prep_bins() # prepare before module imports - -from patch import do_patch -from sslconnection import SSLContext, SSL, SSLConnection -from demux import force_routing_demux, reset_default_demux -import err as error_codes + except IOError: + pass + +_prep_bins() # prepare before module imports + +from patch import do_patch +from sslconnection import SSLContext, SSL, SSLConnection +from demux import force_routing_demux, reset_default_demux +from wrapper import DtlsSocket, client as wrap_client, server as wrap_server diff --git a/dtls/err.py b/dtls/err.py index b869288..25e8e43 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -42,24 +42,31 @@ SSL_ERROR_WANT_ACCEPT = 8 ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 -ERR_NO_CIPHER = 501 -ERR_READ_TIMEOUT = 502 -ERR_WRITE_TIMEOUT = 503 -ERR_HANDSHAKE_TIMEOUT = 504 -ERR_PORT_UNREACHABLE = 505 - -ERR_WRONG_SSL_VERSION = 0x1409210A -ERR_WRONG_VERSION_NUMBER = 0x1408A10B -ERR_COOKIE_MISMATCH = 0x1408A134 -ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086 -ERR_NO_SHARED_CIPHER = 0x1408A0C1 -ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5 - - -class SSLError(socket_error): - """This exception is raised by modules in the dtls package.""" - def __init__(self, *args): - super(SSLError, self).__init__(*args) +ERR_NO_CIPHER = 501 +ERR_READ_TIMEOUT = 502 +ERR_WRITE_TIMEOUT = 503 +ERR_HANDSHAKE_TIMEOUT = 504 +ERR_PORT_UNREACHABLE = 505 + +ERR_WRONG_SSL_VERSION = 0x1409210A +ERR_WRONG_VERSION_NUMBER = 0x1408A10B +ERR_COOKIE_MISMATCH = 0x1408A134 +ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086 +ERR_NO_SHARED_CIPHER = 0x1408A0C1 +ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5 +ERR_TLSV1_ALERT_UNKNOWN_CA = 0x14102418 + +def patch_ssl_errors(): + import ssl + errors = [i for i in globals().iteritems() if type(i[1]) == int and str(i[0]).startswith('ERR_')] + for k, v in errors: + if not hasattr(ssl, k): + setattr(ssl, k, v) + +class SSLError(socket_error): + """This exception is raised by modules in the dtls package.""" + def __init__(self, *args): + super(SSLError, self).__init__(*args) class InvalidSocketError(Exception): diff --git a/dtls/patch.py b/dtls/patch.py index 2ca94a3..b7c7062 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -43,7 +43,9 @@ import errno from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO -from err import raise_as_ssl_module_error +from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \ + SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK, SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR, SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR +from err import raise_as_ssl_module_error, patch_ssl_errors def do_patch(): @@ -64,10 +66,17 @@ def do_patch(): ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO + ssl.SSL_BUILD_CHAIN_FLAG_NONE = SSL_BUILD_CHAIN_FLAG_NONE + ssl.SSL_BUILD_CHAIN_FLAG_UNTRUSTED = SSL_BUILD_CHAIN_FLAG_UNTRUSTED + ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT = SSL_BUILD_CHAIN_FLAG_NO_ROOT + ssl.SSL_BUILD_CHAIN_FLAG_CHECK = SSL_BUILD_CHAIN_FLAG_CHECK + ssl.SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR + ssl.SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR _orig_SSLSocket_init = ssl.SSLSocket.__init__ _orig_get_server_certificate = ssl.get_server_certificate ssl.SSLSocket.__init__ = _SSLSocket_init ssl.get_server_certificate = _get_server_certificate + patch_ssl_errors() raise_as_ssl_module_error() def _wrap_socket(sock, keyfile=None, certfile=None, @@ -198,24 +207,24 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, self._user_config_ssl = cb_user_config_ssl # Perform method substitution and addition (without reference cycle) - self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) - self.listen = MethodType(_SSLSocket_listen, proxy(self)) - self.accept = MethodType(_SSLSocket_accept, proxy(self)) - self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) - self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) - - # Extra - self.getpeercertchain = MethodType(_getpeercertchain, proxy(self)) - -def _getpeercertchain(self, binary_form=False): - return self._sslobj.getpeercertchain(binary_form) - -def _SSLSocket_listen(self, ignored): - if self._connected: - raise ValueError("attempt to listen on connected SSLSocket!") - if self._sslobj: - return - self._sslobj = SSLConnection(socket(_sock=self._sock), + self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) + self.listen = MethodType(_SSLSocket_listen, proxy(self)) + self.accept = MethodType(_SSLSocket_accept, proxy(self)) + self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) + self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) + + # Extra + self.getpeercertchain = MethodType(_getpeercertchain, proxy(self)) + +def _getpeercertchain(self, binary_form=False): + return self._sslobj.getpeercertchain(binary_form) + +def _SSLSocket_listen(self, ignored): + if self._connected: + raise ValueError("attempt to listen on connected SSLSocket!") + if self._sslobj: + return + self._sslobj = SSLConnection(socket(_sock=self._sock), self.keyfile, self.certfile, True, self.cert_reqs, self.ssl_version, self.ca_certs, diff --git a/dtls/test/unit_wrapper.py b/dtls/test/unit_wrapper.py index 582e3e8..ffb926c 100644 --- a/dtls/test/unit_wrapper.py +++ b/dtls/test/unit_wrapper.py @@ -15,8 +15,7 @@ from logging import basicConfig, DEBUG, getLogger _logger = getLogger(__name__) import ssl -from dtls import do_patch, error_codes -from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT +from dtls import DtlsSocket HOST = "localhost" @@ -186,7 +185,7 @@ tests = [ 'client_sigalgs': None}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_WRONG_SSL_VERSION, + 'error_code': ssl.ERR_WRONG_SSL_VERSION, 'exception': None}}, {'testcase': {'name': 'certificate verify fails', @@ -209,7 +208,7 @@ tests = [ 'client_sigalgs': None}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED, + 'error_code': ssl.ERR_CERTIFICATE_VERIFY_FAILED, 'exception': None}}, {'testcase': {'name': 'no matching curve', @@ -232,7 +231,7 @@ tests = [ 'client_sigalgs': None}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE, 'exception': None}}, {'testcase': {'name': 'matching curve', @@ -278,7 +277,7 @@ tests = [ 'client_sigalgs': None}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_PORT_UNREACHABLE, + 'error_code': ssl.ERR_PORT_UNREACHABLE, 'exception': None}}, {'testcase': {'name': 'no matching sigalgs', @@ -301,7 +300,7 @@ tests = [ 'client_sigalgs': "RSA+SHA256"}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE, 'exception': None}}, {'testcase': {'name': 'matching sigalgs', @@ -347,7 +346,7 @@ tests = [ 'client_sigalgs': None}, 'result': {'ret_success': False, - 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE, 'exception': None}}, {'testcase': {'name': 'matching cipher', @@ -493,13 +492,8 @@ class TestSequenceMeta(type): class WrapperTests(unittest.TestCase): __metaclass__ = TestSequenceMeta - def setUp(self): - super(WrapperTests, self).setUp() - - do_patch() - def test_build_cert_chain(self): - steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT] + steps = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT] chatty, connectionchatty = CHATTY, CHATTY_CLIENT indata = 'FOO' certs = dict() diff --git a/dtls/wrapper.py b/dtls/wrapper.py index ef525c5..87f9f46 100644 --- a/dtls/wrapper.py +++ b/dtls/wrapper.py @@ -38,12 +38,36 @@ import socket from patch import do_patch do_patch() from sslconnection import SSLContext, SSL -from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \ - SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK +import err as err_codes _logger = getLogger(__name__) +def client(sock, keyfile=None, certfile=None, + cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None, + do_handshake_on_connect=True, suppress_ragged_eofs=True, + ciphers=None, curves=None, sigalgs=None, user_mtu=None): + + return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False, + cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu, + server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE) + + +def server(sock, keyfile=None, certfile=None, + cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None, + do_handshake_on_connect=False, suppress_ragged_eofs=True, + ciphers=None, curves=None, sigalgs=None, user_mtu=None, + server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): + + return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True, + cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu, + server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options) + + class DtlsSocket(object): class _ClientSession(object): @@ -57,7 +81,7 @@ class DtlsSocket(object): return self.host, self.port def __init__(self, - peerOrSock, + sock=None, keyfile=None, certfile=None, server_side=False, @@ -71,13 +95,12 @@ class DtlsSocket(object): sigalgs=None, user_mtu=None, server_key_exchange_curve=None, - server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE): + server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): if server_cert_options is None: - server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE + server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE self._ssl_logging = False - self._peer = None self._server_side = server_side self._ciphers = ciphers self._curves = curves @@ -87,15 +110,11 @@ class DtlsSocket(object): self._server_cert_options = server_cert_options # Default socket creation - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - if isinstance(peerOrSock, tuple): - # Address tuple - self._peer = peerOrSock - else: - # Socket, use given - sock = peerOrSock + _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + if isinstance(sock, socket.socket): + _sock = sock - self._sock = ssl.wrap_socket(sock, + self._sock = ssl.wrap_socket(_sock, keyfile=keyfile, certfile=certfile, server_side=self._server_side, @@ -112,13 +131,6 @@ class DtlsSocket(object): self._clients = {} self._timeout = None - if self._peer: - self._sock.bind(self._peer) - self._sock.listen(0) - else: - if self._peer: - self._sock.connect(self._peer) - def __getattr__(self, item): if hasattr(self, "_sock") and hasattr(self._sock, item): return getattr(self._sock, item) @@ -159,17 +171,12 @@ class DtlsSocket(object): for cli in self._clients.keys(): cli.close() else: - self._sock.unwrap() + try: + self._sock.unwrap() + except: + pass self._sock.close() - def write(self, data): - # return self._sock.write(data) - return self.sendto(data, self._peer) - - def read(self, len=1024): - # return self._sock.read(len=len) - return self.recvfrom(len)[0] - def recvfrom(self, bufsize, flags=0): if self._server_side: return self._recvfrom_on_server_side(bufsize, flags=flags) @@ -180,11 +187,13 @@ class DtlsSocket(object): try: r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) - except socket.timeout as e_timeout: - raise e_timeout + except socket.timeout: + # __Nothing__ received from any client + raise socket.timeout try: - for conn in r: # type: ssl.SSLSocket + for conn in r: + _last_peer = conn.getpeername() if conn._connected else None if self._sockIsServerSock(conn): # Connect self._clientAccept(conn) @@ -195,38 +204,43 @@ class DtlsSocket(object): # Normal read else: buf = self._clientRead(conn, bufsize) - if buf and conn in self._clients: - return buf, self._clients[conn].getAddr() + if buf: + if conn in self._clients: + return buf, self._clients[conn].getAddr() + else: + _logger.debug('Received data from an already disconnected client!') except Exception as e: + setattr(e, 'peer', _last_peer) raise e try: for conn in self._getClientReadingSockets(): if conn.get_timeout(): - conn.handle_timeout() + ret = conn.handle_timeout() + _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret)) except Exception as e: raise e + # __No_data__ received from any client raise socket.timeout def _recvfrom_on_client_side(self, bufsize, flags): try: buf = self._sock.recv(bufsize, flags) - except ssl.SSLError as e_ssl: - if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN: - return '', self._peer - elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: - raise e_ssl - else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + except ssl.SSLError as e: + if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ: pass + else: + raise e else: if buf: - return buf, self._peer + return buf, self._sock.getpeername() + # __No_data__ received from any client raise socket.timeout def sendto(self, buf, address): @@ -242,19 +256,13 @@ class DtlsSocket(object): return 0 def _sendto_from_client_side(self, buf, address): - while True: - try: - bytes_sent = self._sock.send(buf) + try: + if not self._sock._connected: + self._sock.connect(address) + bytes_sent = self._sock.send(buf) - except ssl.SSLError as e_ssl: - if str(e_ssl).startswith("503:"): - # The write operation timed out - continue - raise e_ssl - - else: - if bytes_sent: - break + except ssl.SSLError as e: + raise e return bytes_sent @@ -278,14 +286,15 @@ class DtlsSocket(object): ret = conn.accept() _logger.debug('Accept returned with ... %s' % (str(ret))) - except Exception as e_accept: - raise e_accept + except Exception as e: + raise e else: if ret: client, addr = ret host, port = addr if client in self._clients: + _logger.debug('Client already connected %s' % str(client)) raise ValueError self._clients[client] = self._ClientSession(host=host, port=port) @@ -297,17 +306,16 @@ class DtlsSocket(object): try: conn.do_handshake() - _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr()))) + _logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr()))) self._clients[conn].handshake_done = True - except ssl.SSLError as e_handshake: - if str(e_handshake).startswith("504:"): - pass - elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ: + except ssl.SSLError as e: + if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ: pass else: - raise e_handshake + self._clientDrop(conn, error=e) + raise e def _clientRead(self, conn, bufsize=4096): _logger.debug('*' * 60) @@ -317,13 +325,11 @@ class DtlsSocket(object): ret = conn.recv(bufsize) _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) - except ssl.SSLError as e_read: - if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN: - self._clientDrop(conn) - elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: - self._clientDrop(conn, error=e_read) - else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + except ssl.SSLError as e: + if e.args[0] == ssl.SSL_ERROR_WANT_READ: pass + else: + self._clientDrop(conn, error=e) return ret @@ -338,8 +344,8 @@ class DtlsSocket(object): ret = conn.send(_data) _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) - except Exception as e_write: - raise e_write + except Exception as e: + raise e return ret @@ -354,8 +360,11 @@ class DtlsSocket(object): if conn in self._clients: del self._clients[conn] - conn.unwrap() + try: + conn.unwrap() + except: + pass conn.close() - except Exception as e_drop: + except Exception as e: pass