Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()
* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client() * dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants * dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors() * dtls/wrapper.py: - Added a server and client function to alias/wrap DtlsSocket() creation - Cleanup of DtlsSocket.__init__() - Cleanup of exception handling in all member methods - Cleanup sendto() from client: no endless loop and first do a connect if not already connected * dtls/test/unit_wrapper.py: Adopt the changes made described aboveincoming
parent
d12b23ba9f
commit
dade3b8213
14
ChangeLog
14
ChangeLog
|
@ -1,3 +1,17 @@
|
|||
2017-03-23 Björn Freise <mcfreis@gmx.net>
|
||||
|
||||
Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()
|
||||
|
||||
* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client()
|
||||
* dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants
|
||||
* dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors()
|
||||
* dtls/wrapper.py:
|
||||
- Added a server and client function to alias/wrap DtlsSocket() creation
|
||||
- Cleanup of DtlsSocket.__init__()
|
||||
- Cleanup of exception handling in all member methods
|
||||
- Cleanup sendto() from client: no endless loop and first do a connect if not already connected
|
||||
* dtls/test/unit_wrapper.py: Adopt the changes made described above
|
||||
|
||||
2017-03-17 Björn Freise <mcfreis@gmx.net>
|
||||
|
||||
Added a wrapper for a DTLS-Socket either as client or server - including unit tests
|
||||
|
|
|
@ -61,4 +61,4 @@ _prep_bins() # prepare before module imports
|
|||
from patch import do_patch
|
||||
from sslconnection import SSLContext, SSL, SSLConnection
|
||||
from demux import force_routing_demux, reset_default_demux
|
||||
import err as error_codes
|
||||
from wrapper import DtlsSocket, client as wrap_client, server as wrap_server
|
||||
|
|
|
@ -54,7 +54,14 @@ ERR_COOKIE_MISMATCH = 0x1408A134
|
|||
ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
|
||||
ERR_NO_SHARED_CIPHER = 0x1408A0C1
|
||||
ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
|
||||
ERR_TLSV1_ALERT_UNKNOWN_CA = 0x14102418
|
||||
|
||||
def patch_ssl_errors():
|
||||
import ssl
|
||||
errors = [i for i in globals().iteritems() if type(i[1]) == int and str(i[0]).startswith('ERR_')]
|
||||
for k, v in errors:
|
||||
if not hasattr(ssl, k):
|
||||
setattr(ssl, k, v)
|
||||
|
||||
class SSLError(socket_error):
|
||||
"""This exception is raised by modules in the dtls package."""
|
||||
|
|
|
@ -43,7 +43,9 @@ import errno
|
|||
|
||||
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
|
||||
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
|
||||
from err import raise_as_ssl_module_error
|
||||
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
|
||||
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK, SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR, SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
|
||||
from err import raise_as_ssl_module_error, patch_ssl_errors
|
||||
|
||||
|
||||
def do_patch():
|
||||
|
@ -64,10 +66,17 @@ def do_patch():
|
|||
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
|
||||
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
|
||||
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_NONE = SSL_BUILD_CHAIN_FLAG_NONE
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_UNTRUSTED = SSL_BUILD_CHAIN_FLAG_UNTRUSTED
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT = SSL_BUILD_CHAIN_FLAG_NO_ROOT
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_CHECK = SSL_BUILD_CHAIN_FLAG_CHECK
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR
|
||||
ssl.SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
|
||||
_orig_SSLSocket_init = ssl.SSLSocket.__init__
|
||||
_orig_get_server_certificate = ssl.get_server_certificate
|
||||
ssl.SSLSocket.__init__ = _SSLSocket_init
|
||||
ssl.get_server_certificate = _get_server_certificate
|
||||
patch_ssl_errors()
|
||||
raise_as_ssl_module_error()
|
||||
|
||||
def _wrap_socket(sock, keyfile=None, certfile=None,
|
||||
|
|
|
@ -15,8 +15,7 @@ from logging import basicConfig, DEBUG, getLogger
|
|||
_logger = getLogger(__name__)
|
||||
|
||||
import ssl
|
||||
from dtls import do_patch, error_codes
|
||||
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
|
||||
from dtls import DtlsSocket
|
||||
|
||||
|
||||
HOST = "localhost"
|
||||
|
@ -186,7 +185,7 @@ tests = [
|
|||
'client_sigalgs': None},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_WRONG_SSL_VERSION,
|
||||
'error_code': ssl.ERR_WRONG_SSL_VERSION,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'certificate verify fails',
|
||||
|
@ -209,7 +208,7 @@ tests = [
|
|||
'client_sigalgs': None},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED,
|
||||
'error_code': ssl.ERR_CERTIFICATE_VERIFY_FAILED,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'no matching curve',
|
||||
|
@ -232,7 +231,7 @@ tests = [
|
|||
'client_sigalgs': None},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'matching curve',
|
||||
|
@ -278,7 +277,7 @@ tests = [
|
|||
'client_sigalgs': None},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_PORT_UNREACHABLE,
|
||||
'error_code': ssl.ERR_PORT_UNREACHABLE,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'no matching sigalgs',
|
||||
|
@ -301,7 +300,7 @@ tests = [
|
|||
'client_sigalgs': "RSA+SHA256"},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'matching sigalgs',
|
||||
|
@ -347,7 +346,7 @@ tests = [
|
|||
'client_sigalgs': None},
|
||||
'result':
|
||||
{'ret_success': False,
|
||||
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
|
||||
'exception': None}},
|
||||
{'testcase':
|
||||
{'name': 'matching cipher',
|
||||
|
@ -493,13 +492,8 @@ class TestSequenceMeta(type):
|
|||
class WrapperTests(unittest.TestCase):
|
||||
__metaclass__ = TestSequenceMeta
|
||||
|
||||
def setUp(self):
|
||||
super(WrapperTests, self).setUp()
|
||||
|
||||
do_patch()
|
||||
|
||||
def test_build_cert_chain(self):
|
||||
steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT]
|
||||
steps = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT]
|
||||
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
|
||||
indata = 'FOO'
|
||||
certs = dict()
|
||||
|
|
145
dtls/wrapper.py
145
dtls/wrapper.py
|
@ -38,12 +38,36 @@ import socket
|
|||
from patch import do_patch
|
||||
do_patch()
|
||||
from sslconnection import SSLContext, SSL
|
||||
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
|
||||
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
|
||||
import err as err_codes
|
||||
|
||||
_logger = getLogger(__name__)
|
||||
|
||||
|
||||
def client(sock, keyfile=None, certfile=None,
|
||||
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
|
||||
do_handshake_on_connect=True, suppress_ragged_eofs=True,
|
||||
ciphers=None, curves=None, sigalgs=None, user_mtu=None):
|
||||
|
||||
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
|
||||
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
|
||||
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
|
||||
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
|
||||
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
|
||||
|
||||
|
||||
def server(sock, keyfile=None, certfile=None,
|
||||
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
|
||||
do_handshake_on_connect=False, suppress_ragged_eofs=True,
|
||||
ciphers=None, curves=None, sigalgs=None, user_mtu=None,
|
||||
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
|
||||
|
||||
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
|
||||
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
|
||||
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
|
||||
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
|
||||
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
|
||||
|
||||
|
||||
class DtlsSocket(object):
|
||||
|
||||
class _ClientSession(object):
|
||||
|
@ -57,7 +81,7 @@ class DtlsSocket(object):
|
|||
return self.host, self.port
|
||||
|
||||
def __init__(self,
|
||||
peerOrSock,
|
||||
sock=None,
|
||||
keyfile=None,
|
||||
certfile=None,
|
||||
server_side=False,
|
||||
|
@ -71,13 +95,12 @@ class DtlsSocket(object):
|
|||
sigalgs=None,
|
||||
user_mtu=None,
|
||||
server_key_exchange_curve=None,
|
||||
server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE):
|
||||
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
|
||||
|
||||
if server_cert_options is None:
|
||||
server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE
|
||||
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
|
||||
|
||||
self._ssl_logging = False
|
||||
self._peer = None
|
||||
self._server_side = server_side
|
||||
self._ciphers = ciphers
|
||||
self._curves = curves
|
||||
|
@ -87,15 +110,11 @@ class DtlsSocket(object):
|
|||
self._server_cert_options = server_cert_options
|
||||
|
||||
# Default socket creation
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
if isinstance(peerOrSock, tuple):
|
||||
# Address tuple
|
||||
self._peer = peerOrSock
|
||||
else:
|
||||
# Socket, use given
|
||||
sock = peerOrSock
|
||||
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
if isinstance(sock, socket.socket):
|
||||
_sock = sock
|
||||
|
||||
self._sock = ssl.wrap_socket(sock,
|
||||
self._sock = ssl.wrap_socket(_sock,
|
||||
keyfile=keyfile,
|
||||
certfile=certfile,
|
||||
server_side=self._server_side,
|
||||
|
@ -112,13 +131,6 @@ class DtlsSocket(object):
|
|||
self._clients = {}
|
||||
self._timeout = None
|
||||
|
||||
if self._peer:
|
||||
self._sock.bind(self._peer)
|
||||
self._sock.listen(0)
|
||||
else:
|
||||
if self._peer:
|
||||
self._sock.connect(self._peer)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if hasattr(self, "_sock") and hasattr(self._sock, item):
|
||||
return getattr(self._sock, item)
|
||||
|
@ -159,17 +171,12 @@ class DtlsSocket(object):
|
|||
for cli in self._clients.keys():
|
||||
cli.close()
|
||||
else:
|
||||
try:
|
||||
self._sock.unwrap()
|
||||
except:
|
||||
pass
|
||||
self._sock.close()
|
||||
|
||||
def write(self, data):
|
||||
# return self._sock.write(data)
|
||||
return self.sendto(data, self._peer)
|
||||
|
||||
def read(self, len=1024):
|
||||
# return self._sock.read(len=len)
|
||||
return self.recvfrom(len)[0]
|
||||
|
||||
def recvfrom(self, bufsize, flags=0):
|
||||
if self._server_side:
|
||||
return self._recvfrom_on_server_side(bufsize, flags=flags)
|
||||
|
@ -180,11 +187,13 @@ class DtlsSocket(object):
|
|||
try:
|
||||
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
|
||||
|
||||
except socket.timeout as e_timeout:
|
||||
raise e_timeout
|
||||
except socket.timeout:
|
||||
# __Nothing__ received from any client
|
||||
raise socket.timeout
|
||||
|
||||
try:
|
||||
for conn in r: # type: ssl.SSLSocket
|
||||
for conn in r:
|
||||
_last_peer = conn.getpeername() if conn._connected else None
|
||||
if self._sockIsServerSock(conn):
|
||||
# Connect
|
||||
self._clientAccept(conn)
|
||||
|
@ -195,38 +204,43 @@ class DtlsSocket(object):
|
|||
# Normal read
|
||||
else:
|
||||
buf = self._clientRead(conn, bufsize)
|
||||
if buf and conn in self._clients:
|
||||
if buf:
|
||||
if conn in self._clients:
|
||||
return buf, self._clients[conn].getAddr()
|
||||
else:
|
||||
_logger.debug('Received data from an already disconnected client!')
|
||||
|
||||
except Exception as e:
|
||||
setattr(e, 'peer', _last_peer)
|
||||
raise e
|
||||
|
||||
try:
|
||||
for conn in self._getClientReadingSockets():
|
||||
if conn.get_timeout():
|
||||
conn.handle_timeout()
|
||||
ret = conn.handle_timeout()
|
||||
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# __No_data__ received from any client
|
||||
raise socket.timeout
|
||||
|
||||
def _recvfrom_on_client_side(self, bufsize, flags):
|
||||
try:
|
||||
buf = self._sock.recv(bufsize, flags)
|
||||
|
||||
except ssl.SSLError as e_ssl:
|
||||
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
|
||||
return '', self._peer
|
||||
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
|
||||
raise e_ssl
|
||||
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
|
||||
except ssl.SSLError as e:
|
||||
if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
else:
|
||||
if buf:
|
||||
return buf, self._peer
|
||||
return buf, self._sock.getpeername()
|
||||
|
||||
# __No_data__ received from any client
|
||||
raise socket.timeout
|
||||
|
||||
def sendto(self, buf, address):
|
||||
|
@ -242,19 +256,13 @@ class DtlsSocket(object):
|
|||
return 0
|
||||
|
||||
def _sendto_from_client_side(self, buf, address):
|
||||
while True:
|
||||
try:
|
||||
if not self._sock._connected:
|
||||
self._sock.connect(address)
|
||||
bytes_sent = self._sock.send(buf)
|
||||
|
||||
except ssl.SSLError as e_ssl:
|
||||
if str(e_ssl).startswith("503:"):
|
||||
# The write operation timed out
|
||||
continue
|
||||
raise e_ssl
|
||||
|
||||
else:
|
||||
if bytes_sent:
|
||||
break
|
||||
except ssl.SSLError as e:
|
||||
raise e
|
||||
|
||||
return bytes_sent
|
||||
|
||||
|
@ -278,14 +286,15 @@ class DtlsSocket(object):
|
|||
ret = conn.accept()
|
||||
_logger.debug('Accept returned with ... %s' % (str(ret)))
|
||||
|
||||
except Exception as e_accept:
|
||||
raise e_accept
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
else:
|
||||
if ret:
|
||||
client, addr = ret
|
||||
host, port = addr
|
||||
if client in self._clients:
|
||||
_logger.debug('Client already connected %s' % str(client))
|
||||
raise ValueError
|
||||
self._clients[client] = self._ClientSession(host=host, port=port)
|
||||
|
||||
|
@ -297,17 +306,16 @@ class DtlsSocket(object):
|
|||
|
||||
try:
|
||||
conn.do_handshake()
|
||||
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
|
||||
_logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
|
||||
|
||||
self._clients[conn].handshake_done = True
|
||||
|
||||
except ssl.SSLError as e_handshake:
|
||||
if str(e_handshake).startswith("504:"):
|
||||
pass
|
||||
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||
except ssl.SSLError as e:
|
||||
if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||
pass
|
||||
else:
|
||||
raise e_handshake
|
||||
self._clientDrop(conn, error=e)
|
||||
raise e
|
||||
|
||||
def _clientRead(self, conn, bufsize=4096):
|
||||
_logger.debug('*' * 60)
|
||||
|
@ -317,13 +325,11 @@ class DtlsSocket(object):
|
|||
ret = conn.recv(bufsize)
|
||||
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
|
||||
|
||||
except ssl.SSLError as e_read:
|
||||
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
|
||||
self._clientDrop(conn)
|
||||
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
|
||||
self._clientDrop(conn, error=e_read)
|
||||
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
|
||||
except ssl.SSLError as e:
|
||||
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||
pass
|
||||
else:
|
||||
self._clientDrop(conn, error=e)
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -338,8 +344,8 @@ class DtlsSocket(object):
|
|||
ret = conn.send(_data)
|
||||
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
|
||||
|
||||
except Exception as e_write:
|
||||
raise e_write
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -354,8 +360,11 @@ class DtlsSocket(object):
|
|||
|
||||
if conn in self._clients:
|
||||
del self._clients[conn]
|
||||
try:
|
||||
conn.unwrap()
|
||||
except:
|
||||
pass
|
||||
conn.close()
|
||||
|
||||
except Exception as e_drop:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue