Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()

* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client()
* dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants
* dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors()
* dtls/wrapper.py:
	- Added a server and client function to alias/wrap DtlsSocket() creation
	- Cleanup of DtlsSocket.__init__()
	- Cleanup of exception handling in all member methods
	- Cleanup sendto() from client: no endless loop and first do a connect if not already connected
* dtls/test/unit_wrapper.py: Adopt the changes made described above
incoming
mcfreis 2017-03-23 14:08:20 +01:00
parent d12b23ba9f
commit dade3b8213
6 changed files with 166 additions and 133 deletions

View File

@ -1,3 +1,17 @@
2017-03-23 Björn Freise <mcfreis@gmx.net>
Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()
* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client()
* dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants
* dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors()
* dtls/wrapper.py:
- Added a server and client function to alias/wrap DtlsSocket() creation
- Cleanup of DtlsSocket.__init__()
- Cleanup of exception handling in all member methods
- Cleanup sendto() from client: no endless loop and first do a connect if not already connected
* dtls/test/unit_wrapper.py: Adopt the changes made described above
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Added a wrapper for a DTLS-Socket either as client or server - including unit tests Added a wrapper for a DTLS-Socket either as client or server - including unit tests

View File

@ -61,4 +61,4 @@ _prep_bins() # prepare before module imports
from patch import do_patch from patch import do_patch
from sslconnection import SSLContext, SSL, SSLConnection from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux from demux import force_routing_demux, reset_default_demux
import err as error_codes from wrapper import DtlsSocket, client as wrap_client, server as wrap_server

View File

@ -54,7 +54,14 @@ ERR_COOKIE_MISMATCH = 0x1408A134
ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086 ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
ERR_NO_SHARED_CIPHER = 0x1408A0C1 ERR_NO_SHARED_CIPHER = 0x1408A0C1
ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5 ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
ERR_TLSV1_ALERT_UNKNOWN_CA = 0x14102418
def patch_ssl_errors():
import ssl
errors = [i for i in globals().iteritems() if type(i[1]) == int and str(i[0]).startswith('ERR_')]
for k, v in errors:
if not hasattr(ssl, k):
setattr(ssl, k, v)
class SSLError(socket_error): class SSLError(socket_error):
"""This exception is raised by modules in the dtls package.""" """This exception is raised by modules in the dtls package."""

View File

@ -43,7 +43,9 @@ import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK, SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR, SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
from err import raise_as_ssl_module_error, patch_ssl_errors
def do_patch(): def do_patch():
@ -64,10 +66,17 @@ def do_patch():
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
ssl.SSL_BUILD_CHAIN_FLAG_NONE = SSL_BUILD_CHAIN_FLAG_NONE
ssl.SSL_BUILD_CHAIN_FLAG_UNTRUSTED = SSL_BUILD_CHAIN_FLAG_UNTRUSTED
ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT = SSL_BUILD_CHAIN_FLAG_NO_ROOT
ssl.SSL_BUILD_CHAIN_FLAG_CHECK = SSL_BUILD_CHAIN_FLAG_CHECK
ssl.SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR
ssl.SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
_orig_SSLSocket_init = ssl.SSLSocket.__init__ _orig_SSLSocket_init = ssl.SSLSocket.__init__
_orig_get_server_certificate = ssl.get_server_certificate _orig_get_server_certificate = ssl.get_server_certificate
ssl.SSLSocket.__init__ = _SSLSocket_init ssl.SSLSocket.__init__ = _SSLSocket_init
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
patch_ssl_errors()
raise_as_ssl_module_error() raise_as_ssl_module_error()
def _wrap_socket(sock, keyfile=None, certfile=None, def _wrap_socket(sock, keyfile=None, certfile=None,

View File

@ -15,8 +15,7 @@ from logging import basicConfig, DEBUG, getLogger
_logger = getLogger(__name__) _logger = getLogger(__name__)
import ssl import ssl
from dtls import do_patch, error_codes from dtls import DtlsSocket
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
HOST = "localhost" HOST = "localhost"
@ -186,7 +185,7 @@ tests = [
'client_sigalgs': None}, 'client_sigalgs': None},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_WRONG_SSL_VERSION, 'error_code': ssl.ERR_WRONG_SSL_VERSION,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'certificate verify fails', {'name': 'certificate verify fails',
@ -209,7 +208,7 @@ tests = [
'client_sigalgs': None}, 'client_sigalgs': None},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED, 'error_code': ssl.ERR_CERTIFICATE_VERIFY_FAILED,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'no matching curve', {'name': 'no matching curve',
@ -232,7 +231,7 @@ tests = [
'client_sigalgs': None}, 'client_sigalgs': None},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'matching curve', {'name': 'matching curve',
@ -278,7 +277,7 @@ tests = [
'client_sigalgs': None}, 'client_sigalgs': None},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_PORT_UNREACHABLE, 'error_code': ssl.ERR_PORT_UNREACHABLE,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'no matching sigalgs', {'name': 'no matching sigalgs',
@ -301,7 +300,7 @@ tests = [
'client_sigalgs': "RSA+SHA256"}, 'client_sigalgs': "RSA+SHA256"},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'matching sigalgs', {'name': 'matching sigalgs',
@ -347,7 +346,7 @@ tests = [
'client_sigalgs': None}, 'client_sigalgs': None},
'result': 'result':
{'ret_success': False, {'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, 'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}}, 'exception': None}},
{'testcase': {'testcase':
{'name': 'matching cipher', {'name': 'matching cipher',
@ -493,13 +492,8 @@ class TestSequenceMeta(type):
class WrapperTests(unittest.TestCase): class WrapperTests(unittest.TestCase):
__metaclass__ = TestSequenceMeta __metaclass__ = TestSequenceMeta
def setUp(self):
super(WrapperTests, self).setUp()
do_patch()
def test_build_cert_chain(self): def test_build_cert_chain(self):
steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT] steps = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT]
chatty, connectionchatty = CHATTY, CHATTY_CLIENT chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO' indata = 'FOO'
certs = dict() certs = dict()

View File

@ -38,12 +38,36 @@ import socket
from patch import do_patch from patch import do_patch
do_patch() do_patch()
from sslconnection import SSLContext, SSL from sslconnection import SSLContext, SSL
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \ import err as err_codes
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
_logger = getLogger(__name__) _logger = getLogger(__name__)
def client(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
def server(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=False, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
class DtlsSocket(object): class DtlsSocket(object):
class _ClientSession(object): class _ClientSession(object):
@ -57,7 +81,7 @@ class DtlsSocket(object):
return self.host, self.port return self.host, self.port
def __init__(self, def __init__(self,
peerOrSock, sock=None,
keyfile=None, keyfile=None,
certfile=None, certfile=None,
server_side=False, server_side=False,
@ -71,13 +95,12 @@ class DtlsSocket(object):
sigalgs=None, sigalgs=None,
user_mtu=None, user_mtu=None,
server_key_exchange_curve=None, server_key_exchange_curve=None,
server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE): server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None: if server_cert_options is None:
server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False self._ssl_logging = False
self._peer = None
self._server_side = server_side self._server_side = server_side
self._ciphers = ciphers self._ciphers = ciphers
self._curves = curves self._curves = curves
@ -87,15 +110,11 @@ class DtlsSocket(object):
self._server_cert_options = server_cert_options self._server_cert_options = server_cert_options
# Default socket creation # Default socket creation
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(peerOrSock, tuple): if isinstance(sock, socket.socket):
# Address tuple _sock = sock
self._peer = peerOrSock
else:
# Socket, use given
sock = peerOrSock
self._sock = ssl.wrap_socket(sock, self._sock = ssl.wrap_socket(_sock,
keyfile=keyfile, keyfile=keyfile,
certfile=certfile, certfile=certfile,
server_side=self._server_side, server_side=self._server_side,
@ -112,13 +131,6 @@ class DtlsSocket(object):
self._clients = {} self._clients = {}
self._timeout = None self._timeout = None
if self._peer:
self._sock.bind(self._peer)
self._sock.listen(0)
else:
if self._peer:
self._sock.connect(self._peer)
def __getattr__(self, item): def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item): if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item) return getattr(self._sock, item)
@ -159,17 +171,12 @@ class DtlsSocket(object):
for cli in self._clients.keys(): for cli in self._clients.keys():
cli.close() cli.close()
else: else:
try:
self._sock.unwrap() self._sock.unwrap()
except:
pass
self._sock.close() self._sock.close()
def write(self, data):
# return self._sock.write(data)
return self.sendto(data, self._peer)
def read(self, len=1024):
# return self._sock.read(len=len)
return self.recvfrom(len)[0]
def recvfrom(self, bufsize, flags=0): def recvfrom(self, bufsize, flags=0):
if self._server_side: if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags) return self._recvfrom_on_server_side(bufsize, flags=flags)
@ -180,11 +187,13 @@ class DtlsSocket(object):
try: try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout: except socket.timeout:
raise e_timeout # __Nothing__ received from any client
raise socket.timeout
try: try:
for conn in r: # type: ssl.SSLSocket for conn in r:
_last_peer = conn.getpeername() if conn._connected else None
if self._sockIsServerSock(conn): if self._sockIsServerSock(conn):
# Connect # Connect
self._clientAccept(conn) self._clientAccept(conn)
@ -195,38 +204,43 @@ class DtlsSocket(object):
# Normal read # Normal read
else: else:
buf = self._clientRead(conn, bufsize) buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients: if buf:
if conn in self._clients:
return buf, self._clients[conn].getAddr() return buf, self._clients[conn].getAddr()
else:
_logger.debug('Received data from an already disconnected client!')
except Exception as e: except Exception as e:
setattr(e, 'peer', _last_peer)
raise e raise e
try: try:
for conn in self._getClientReadingSockets(): for conn in self._getClientReadingSockets():
if conn.get_timeout(): if conn.get_timeout():
conn.handle_timeout() ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
except Exception as e: except Exception as e:
raise e raise e
# __No_data__ received from any client
raise socket.timeout raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags): def _recvfrom_on_client_side(self, bufsize, flags):
try: try:
buf = self._sock.recv(bufsize, flags) buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl: except ssl.SSLError as e:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN: if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise e_ssl
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass pass
else:
raise e
else: else:
if buf: if buf:
return buf, self._peer return buf, self._sock.getpeername()
# __No_data__ received from any client
raise socket.timeout raise socket.timeout
def sendto(self, buf, address): def sendto(self, buf, address):
@ -242,19 +256,13 @@ class DtlsSocket(object):
return 0 return 0
def _sendto_from_client_side(self, buf, address): def _sendto_from_client_side(self, buf, address):
while True:
try: try:
if not self._sock._connected:
self._sock.connect(address)
bytes_sent = self._sock.send(buf) bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl: except ssl.SSLError as e:
if str(e_ssl).startswith("503:"): raise e
# The write operation timed out
continue
raise e_ssl
else:
if bytes_sent:
break
return bytes_sent return bytes_sent
@ -278,14 +286,15 @@ class DtlsSocket(object):
ret = conn.accept() ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret))) _logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept: except Exception as e:
raise e_accept raise e
else: else:
if ret: if ret:
client, addr = ret client, addr = ret
host, port = addr host, port = addr
if client in self._clients: if client in self._clients:
_logger.debug('Client already connected %s' % str(client))
raise ValueError raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port) self._clients[client] = self._ClientSession(host=host, port=port)
@ -297,17 +306,16 @@ class DtlsSocket(object):
try: try:
conn.do_handshake() conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr()))) _logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake: except ssl.SSLError as e:
if str(e_handshake).startswith("504:"): if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
pass pass
else: else:
raise e_handshake self._clientDrop(conn, error=e)
raise e
def _clientRead(self, conn, bufsize=4096): def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60) _logger.debug('*' * 60)
@ -317,13 +325,11 @@ class DtlsSocket(object):
ret = conn.recv(bufsize) ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read: except ssl.SSLError as e:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN: if e.args[0] == ssl.SSL_ERROR_WANT_READ:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass pass
else:
self._clientDrop(conn, error=e)
return ret return ret
@ -338,8 +344,8 @@ class DtlsSocket(object):
ret = conn.send(_data) ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write: except Exception as e:
raise e_write raise e
return ret return ret
@ -354,8 +360,11 @@ class DtlsSocket(object):
if conn in self._clients: if conn in self._clients:
del self._clients[conn] del self._clients[conn]
try:
conn.unwrap() conn.unwrap()
except:
pass
conn.close() conn.close()
except Exception as e_drop: except Exception as e:
pass pass