Merge pull request #3 from mcfreis/master

Update "Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()"
incoming
mcfreis 2017-03-23 14:14:33 +01:00 committed by GitHub
commit 418655bed0
6 changed files with 166 additions and 133 deletions

View File

@ -1,3 +1,17 @@
2017-03-23 Björn Freise <mcfreis@gmx.net>
Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()
* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client()
* dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants
* dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors()
* dtls/wrapper.py:
- Added a server and client function to alias/wrap DtlsSocket() creation
- Cleanup of DtlsSocket.__init__()
- Cleanup of exception handling in all member methods
- Cleanup sendto() from client: no endless loop and first do a connect if not already connected
* dtls/test/unit_wrapper.py: Adopt the changes made described above
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added a wrapper for a DTLS-Socket either as client or server - including unit tests

View File

@ -53,12 +53,12 @@ def _prep_bins():
for prebuilt_file in files:
try:
copy(path.join(prebuilt_path, prebuilt_file), package_root)
except IOError:
pass
_prep_bins() # prepare before module imports
from patch import do_patch
from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux
import err as error_codes
except IOError:
pass
_prep_bins() # prepare before module imports
from patch import do_patch
from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux
from wrapper import DtlsSocket, client as wrap_client, server as wrap_server

View File

@ -42,24 +42,31 @@ SSL_ERROR_WANT_ACCEPT = 8
ERR_BOTH_KEY_CERT_FILES = 500
ERR_BOTH_KEY_CERT_FILES_SVR = 298
ERR_NO_CERTS = 331
ERR_NO_CIPHER = 501
ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505
ERR_WRONG_SSL_VERSION = 0x1409210A
ERR_WRONG_VERSION_NUMBER = 0x1408A10B
ERR_COOKIE_MISMATCH = 0x1408A134
ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
ERR_NO_SHARED_CIPHER = 0x1408A0C1
ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
class SSLError(socket_error):
"""This exception is raised by modules in the dtls package."""
def __init__(self, *args):
super(SSLError, self).__init__(*args)
ERR_NO_CIPHER = 501
ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505
ERR_WRONG_SSL_VERSION = 0x1409210A
ERR_WRONG_VERSION_NUMBER = 0x1408A10B
ERR_COOKIE_MISMATCH = 0x1408A134
ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
ERR_NO_SHARED_CIPHER = 0x1408A0C1
ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
ERR_TLSV1_ALERT_UNKNOWN_CA = 0x14102418
def patch_ssl_errors():
import ssl
errors = [i for i in globals().iteritems() if type(i[1]) == int and str(i[0]).startswith('ERR_')]
for k, v in errors:
if not hasattr(ssl, k):
setattr(ssl, k, v)
class SSLError(socket_error):
"""This exception is raised by modules in the dtls package."""
def __init__(self, *args):
super(SSLError, self).__init__(*args)
class InvalidSocketError(Exception):

View File

@ -43,7 +43,9 @@ import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK, SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR, SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
from err import raise_as_ssl_module_error, patch_ssl_errors
def do_patch():
@ -64,10 +66,17 @@ def do_patch():
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
ssl.SSL_BUILD_CHAIN_FLAG_NONE = SSL_BUILD_CHAIN_FLAG_NONE
ssl.SSL_BUILD_CHAIN_FLAG_UNTRUSTED = SSL_BUILD_CHAIN_FLAG_UNTRUSTED
ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT = SSL_BUILD_CHAIN_FLAG_NO_ROOT
ssl.SSL_BUILD_CHAIN_FLAG_CHECK = SSL_BUILD_CHAIN_FLAG_CHECK
ssl.SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR
ssl.SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
_orig_SSLSocket_init = ssl.SSLSocket.__init__
_orig_get_server_certificate = ssl.get_server_certificate
ssl.SSLSocket.__init__ = _SSLSocket_init
ssl.get_server_certificate = _get_server_certificate
patch_ssl_errors()
raise_as_ssl_module_error()
def _wrap_socket(sock, keyfile=None, certfile=None,
@ -198,24 +207,24 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
self._user_config_ssl = cb_user_config_ssl
# Perform method substitution and addition (without reference cycle)
self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self))
self.listen = MethodType(_SSLSocket_listen, proxy(self))
self.accept = MethodType(_SSLSocket_accept, proxy(self))
self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self))
self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self))
# Extra
self.getpeercertchain = MethodType(_getpeercertchain, proxy(self))
def _getpeercertchain(self, binary_form=False):
return self._sslobj.getpeercertchain(binary_form)
def _SSLSocket_listen(self, ignored):
if self._connected:
raise ValueError("attempt to listen on connected SSLSocket!")
if self._sslobj:
return
self._sslobj = SSLConnection(socket(_sock=self._sock),
self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self))
self.listen = MethodType(_SSLSocket_listen, proxy(self))
self.accept = MethodType(_SSLSocket_accept, proxy(self))
self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self))
self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self))
# Extra
self.getpeercertchain = MethodType(_getpeercertchain, proxy(self))
def _getpeercertchain(self, binary_form=False):
return self._sslobj.getpeercertchain(binary_form)
def _SSLSocket_listen(self, ignored):
if self._connected:
raise ValueError("attempt to listen on connected SSLSocket!")
if self._sslobj:
return
self._sslobj = SSLConnection(socket(_sock=self._sock),
self.keyfile, self.certfile, True,
self.cert_reqs, self.ssl_version,
self.ca_certs,

View File

@ -15,8 +15,7 @@ from logging import basicConfig, DEBUG, getLogger
_logger = getLogger(__name__)
import ssl
from dtls import do_patch, error_codes
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
from dtls import DtlsSocket
HOST = "localhost"
@ -186,7 +185,7 @@ tests = [
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_WRONG_SSL_VERSION,
'error_code': ssl.ERR_WRONG_SSL_VERSION,
'exception': None}},
{'testcase':
{'name': 'certificate verify fails',
@ -209,7 +208,7 @@ tests = [
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED,
'error_code': ssl.ERR_CERTIFICATE_VERIFY_FAILED,
'exception': None}},
{'testcase':
{'name': 'no matching curve',
@ -232,7 +231,7 @@ tests = [
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching curve',
@ -278,7 +277,7 @@ tests = [
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_PORT_UNREACHABLE,
'error_code': ssl.ERR_PORT_UNREACHABLE,
'exception': None}},
{'testcase':
{'name': 'no matching sigalgs',
@ -301,7 +300,7 @@ tests = [
'client_sigalgs': "RSA+SHA256"},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching sigalgs',
@ -347,7 +346,7 @@ tests = [
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching cipher',
@ -493,13 +492,8 @@ class TestSequenceMeta(type):
class WrapperTests(unittest.TestCase):
__metaclass__ = TestSequenceMeta
def setUp(self):
super(WrapperTests, self).setUp()
do_patch()
def test_build_cert_chain(self):
steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT]
steps = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT]
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO'
certs = dict()

View File

@ -38,12 +38,36 @@ import socket
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
import err as err_codes
_logger = getLogger(__name__)
def client(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
def server(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=False, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
class DtlsSocket(object):
class _ClientSession(object):
@ -57,7 +81,7 @@ class DtlsSocket(object):
return self.host, self.port
def __init__(self,
peerOrSock,
sock=None,
keyfile=None,
certfile=None,
server_side=False,
@ -71,13 +95,12 @@ class DtlsSocket(object):
sigalgs=None,
user_mtu=None,
server_key_exchange_curve=None,
server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE):
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None:
server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False
self._peer = None
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
@ -87,15 +110,11 @@ class DtlsSocket(object):
self._server_cert_options = server_cert_options
# Default socket creation
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(peerOrSock, tuple):
# Address tuple
self._peer = peerOrSock
else:
# Socket, use given
sock = peerOrSock
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(sock, socket.socket):
_sock = sock
self._sock = ssl.wrap_socket(sock,
self._sock = ssl.wrap_socket(_sock,
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
@ -112,13 +131,6 @@ class DtlsSocket(object):
self._clients = {}
self._timeout = None
if self._peer:
self._sock.bind(self._peer)
self._sock.listen(0)
else:
if self._peer:
self._sock.connect(self._peer)
def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item)
@ -159,17 +171,12 @@ class DtlsSocket(object):
for cli in self._clients.keys():
cli.close()
else:
self._sock.unwrap()
try:
self._sock.unwrap()
except:
pass
self._sock.close()
def write(self, data):
# return self._sock.write(data)
return self.sendto(data, self._peer)
def read(self, len=1024):
# return self._sock.read(len=len)
return self.recvfrom(len)[0]
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
@ -180,11 +187,13 @@ class DtlsSocket(object):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout:
raise e_timeout
except socket.timeout:
# __Nothing__ received from any client
raise socket.timeout
try:
for conn in r: # type: ssl.SSLSocket
for conn in r:
_last_peer = conn.getpeername() if conn._connected else None
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
@ -195,38 +204,43 @@ class DtlsSocket(object):
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients:
return buf, self._clients[conn].getAddr()
if buf:
if conn in self._clients:
return buf, self._clients[conn].getAddr()
else:
_logger.debug('Received data from an already disconnected client!')
except Exception as e:
setattr(e, 'peer', _last_peer)
raise e
try:
for conn in self._getClientReadingSockets():
if conn.get_timeout():
conn.handle_timeout()
ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
except Exception as e:
raise e
# __No_data__ received from any client
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise e_ssl
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
except ssl.SSLError as e:
if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e
else:
if buf:
return buf, self._peer
return buf, self._sock.getpeername()
# __No_data__ received from any client
raise socket.timeout
def sendto(self, buf, address):
@ -242,19 +256,13 @@ class DtlsSocket(object):
return 0
def _sendto_from_client_side(self, buf, address):
while True:
try:
bytes_sent = self._sock.send(buf)
try:
if not self._sock._connected:
self._sock.connect(address)
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl:
if str(e_ssl).startswith("503:"):
# The write operation timed out
continue
raise e_ssl
else:
if bytes_sent:
break
except ssl.SSLError as e:
raise e
return bytes_sent
@ -278,14 +286,15 @@ class DtlsSocket(object):
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept:
raise e_accept
except Exception as e:
raise e
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
_logger.debug('Client already connected %s' % str(client))
raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port)
@ -297,17 +306,16 @@ class DtlsSocket(object):
try:
conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
_logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake:
if str(e_handshake).startswith("504:"):
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
except ssl.SSLError as e:
if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e_handshake
self._clientDrop(conn, error=e)
raise e
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
@ -317,13 +325,11 @@ class DtlsSocket(object):
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
self._clientDrop(conn, error=e)
return ret
@ -338,8 +344,8 @@ class DtlsSocket(object):
ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write:
raise e_write
except Exception as e:
raise e
return ret
@ -354,8 +360,11 @@ class DtlsSocket(object):
if conn in self._clients:
del self._clients[conn]
conn.unwrap()
try:
conn.unwrap()
except:
pass
conn.close()
except Exception as e_drop:
except Exception as e:
pass