Added an interface in SSLConnection() to access SSLContext() and SSL() for manipulating settings during creation

* dtls/openssl.py:
	- Added utility fucntions EC_curve_nist2nid() and EC_curve_nid2nist()
* dtls/patch.py:
	- Extended wrap_socket() arguments with callbacks for user config functions of ssl context and ssl session values
	- Extended SSLSocket() arguments with callbacks for user config functions of ssl context and ssl session values
* dtls/sslconnection.py:
	- Extended SSLConnection() arguments with callbacks for user config functions of ssl context and ssl session values
	- During the init of client and server the corresponding user config functions are called (if given)
	- Added new classes SSLContext() [set_ciphers(), set_sigalgs(), set_curves(), set_ecdh_curve(), build_cert_chain(),
	set_ssl_logging()] and SSL() [set_mtu(), set_link_mtu()]
incoming
mcfreis 2017-03-20 16:00:11 +01:00
parent f5b88155fd
commit 60f76fac83
4 changed files with 289 additions and 113 deletions

View File

@ -1,3 +1,18 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added an interface in SSLConnection() to access SSLContext() and SSL() for manipulating settings during creation
* dtls/openssl.py:
- Added utility fucntions EC_curve_nist2nid() and EC_curve_nid2nist()
* dtls/patch.py:
- Extended wrap_socket() arguments with callbacks for user config functions of ssl context and ssl session values
- Extended SSLSocket() arguments with callbacks for user config functions of ssl context and ssl session values
* dtls/sslconnection.py:
- Extended SSLConnection() arguments with callbacks for user config functions of ssl context and ssl session values
- During the init of client and server the corresponding user config functions are called (if given)
- Added new classes SSLContext() [set_ciphers(), set_sigalgs(), set_curves(), set_ecdh_curve(), build_cert_chain(),
set_ssl_logging()] and SSL() [set_mtu(), set_link_mtu()]
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods getting the curves supported by the runtime openSSL lib Added methods getting the curves supported by the runtime openSSL lib

View File

@ -803,6 +803,10 @@ map(lambda x: _make_function(*x), (
((EC_KEY, "ret"), (c_int, "nid"))), ((EC_KEY, "ret"), (c_int, "nid"))),
("EC_KEY_free", libcrypto, ("EC_KEY_free", libcrypto,
((None, "ret"), (EC_KEY, "key"))), ((None, "ret"), (EC_KEY, "key"))),
("EC_curve_nist2nid", libcrypto,
((c_int, "ret"), (POINTER(c_char), "name")), True, None),
("EC_curve_nid2nist", libcrypto,
((c_char_p, "ret"), (c_int, "nid")), True, None),
)) ))
# #

View File

@ -74,14 +74,19 @@ def _wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True,
ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile, return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs, server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers) ciphers=ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve a server certificate """Retrieve a server certificate
@ -115,7 +120,9 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None, server_hostname=None,
_context=None): _context=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
is_connection = is_datagram = False is_connection = is_datagram = False
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
is_connection = True is_connection = True
@ -167,7 +174,9 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side, cert_reqs, server_side, cert_reqs,
ssl_version, ca_certs, ssl_version, ca_certs,
do_handshake_on_connect, do_handshake_on_connect,
suppress_ragged_eofs, ciphers) suppress_ragged_eofs, ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
else: else:
self._connected = True self._connected = True
self._sslobj = sock self._sslobj = sock
@ -185,6 +194,8 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
self.do_handshake_on_connect = do_handshake_on_connect self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0 self._makefile_refs = 0
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._user_config_ssl = cb_user_config_ssl
# Perform method substitution and addition (without reference cycle) # Perform method substitution and addition (without reference cycle)
self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self))
@ -203,7 +214,9 @@ def _SSLSocket_listen(self, ignored):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
def _SSLSocket_accept(self): def _SSLSocket_accept(self):
if self._connected: if self._connected:
@ -218,7 +231,9 @@ def _SSLSocket_accept(self):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
return new_ssl_sock, addr return new_ssl_sock, addr
def _SSLSocket_real_connect(self, addr, return_errno): def _SSLSocket_real_connect(self, addr, return_errno):
@ -229,7 +244,9 @@ def _SSLSocket_real_connect(self, addr, return_errno):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
try: try:
self._sslobj.connect(addr) self._sslobj.connect(addr)
except socket_error as e: except socket_error as e:

View File

@ -177,6 +177,123 @@ class _CallbackProxy(object):
return self.ssl_func(self.ssl_connection, *args, **kwargs) return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLContext(object):
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u'''
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
self._ctx.key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal = SSL_CTX_set_tmp_ecdh(self._ctx, self._ctx.key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object): class SSLConnection(object):
"""DTLS peer association """DTLS peer association
@ -210,6 +327,7 @@ class SSLConnection(object):
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method())) self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
@ -229,6 +347,7 @@ class SSLConnection(object):
_CallbackProxy(self._generate_cookie_cb), _CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb)) _CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value) SSL_set_accept_state(self._ssl.value)
if peer_address and self._do_handshake_on_connect: if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake() return lambda: self.do_handshake()
@ -245,12 +364,14 @@ class SSLConnection(object):
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method())) self._ctx = _CTX(SSL_CTX_new(client_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
else: else:
verify_mode = SSL_VERIFY_PEER verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode) self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_connect_state(self._ssl.value) SSL_set_connect_state(self._ssl.value)
if peer_address: if peer_address:
return lambda: self.connect(peer_address) return lambda: self.connect(peer_address)
@ -273,7 +394,8 @@ class SSLConnection(object):
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err: except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err) raise_ssl_error(ERR_NO_CIPHER, err)
SSL_CTX_set_info_callback(self._ctx.value, _ssl_logging_cb) if self._user_config_ssl_ctx:
self._user_config_ssl_ctx(self._intf_ssl_ctx)
def _copy_server(self): def _copy_server(self):
source = self._sock source = self._sock
@ -299,7 +421,10 @@ class SSLConnection(object):
BIO_dgram_set_connected(self._wbio.value, BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address) source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value)) source._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(source._ssl.value)
SSL_set_accept_state(source._ssl.value) SSL_set_accept_state(source._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio source._rbio = new_source_rbio
source._wbio = new_source_wbio source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value, SSL_set_bio(source._ssl.value,
@ -318,7 +443,10 @@ class SSLConnection(object):
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address) BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value) SSL_set_accept_state(self._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect: if self._do_handshake_on_connect:
return lambda: self.do_handshake() return lambda: self.do_handshake()
@ -378,7 +506,9 @@ class SSLConnection(object):
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True, ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
"""Constructor """Constructor
Arguments: Arguments:
@ -408,6 +538,11 @@ class SSLConnection(object):
self._handshake_done = False self._handshake_done = False
self._wbio_nb = self._rbio_nb = False self._wbio_nb = self._rbio_nb = False
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._intf_ssl_ctx = None
self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
post_init = self._copy_server() post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket): elif isinstance(sock, _UnwrappedSocket):
@ -424,6 +559,9 @@ class SSLConnection(object):
SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU) SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl.value, 1500) DTLS_set_link_mtu(self._ssl.value, 1500)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown() self._rbio.disown()
self._wbio.disown() self._wbio.disown()
@ -535,7 +673,9 @@ class SSLConnection(object):
new_conn = SSLConnection(self, self._keyfile, self._certfile, True, new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version, self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect, self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers) self._suppress_ragged_eofs, self._ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
new_peer = self._pending_peer_address new_peer = self._pending_peer_address
self._pending_peer_address = None self._pending_peer_address = None
if self._do_handshake_on_connect: if self._do_handshake_on_connect: