Added an interface in SSLConnection() to access SSLContext() and SSL() for manipulating settings during creation

* dtls/openssl.py:
	- Added utility fucntions EC_curve_nist2nid() and EC_curve_nid2nist()
* dtls/patch.py:
	- Extended wrap_socket() arguments with callbacks for user config functions of ssl context and ssl session values
	- Extended SSLSocket() arguments with callbacks for user config functions of ssl context and ssl session values
* dtls/sslconnection.py:
	- Extended SSLConnection() arguments with callbacks for user config functions of ssl context and ssl session values
	- During the init of client and server the corresponding user config functions are called (if given)
	- Added new classes SSLContext() [set_ciphers(), set_sigalgs(), set_curves(), set_ecdh_curve(), build_cert_chain(),
	set_ssl_logging()] and SSL() [set_mtu(), set_link_mtu()]
incoming
mcfreis 2017-03-20 16:00:11 +01:00
parent f5b88155fd
commit 60f76fac83
4 changed files with 289 additions and 113 deletions

View File

@ -1,3 +1,18 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added an interface in SSLConnection() to access SSLContext() and SSL() for manipulating settings during creation
* dtls/openssl.py:
- Added utility fucntions EC_curve_nist2nid() and EC_curve_nid2nist()
* dtls/patch.py:
- Extended wrap_socket() arguments with callbacks for user config functions of ssl context and ssl session values
- Extended SSLSocket() arguments with callbacks for user config functions of ssl context and ssl session values
* dtls/sslconnection.py:
- Extended SSLConnection() arguments with callbacks for user config functions of ssl context and ssl session values
- During the init of client and server the corresponding user config functions are called (if given)
- Added new classes SSLContext() [set_ciphers(), set_sigalgs(), set_curves(), set_ecdh_curve(), build_cert_chain(),
set_ssl_logging()] and SSL() [set_mtu(), set_link_mtu()]
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods getting the curves supported by the runtime openSSL lib Added methods getting the curves supported by the runtime openSSL lib

View File

@ -803,6 +803,10 @@ map(lambda x: _make_function(*x), (
((EC_KEY, "ret"), (c_int, "nid"))), ((EC_KEY, "ret"), (c_int, "nid"))),
("EC_KEY_free", libcrypto, ("EC_KEY_free", libcrypto,
((None, "ret"), (EC_KEY, "key"))), ((None, "ret"), (EC_KEY, "key"))),
("EC_curve_nist2nid", libcrypto,
((c_int, "ret"), (POINTER(c_char), "name")), True, None),
("EC_curve_nid2nist", libcrypto,
((c_char_p, "ret"), (c_int, "nid")), True, None),
)) ))
# #

View File

@ -31,74 +31,79 @@ has the following effects:
* Direct instantiation of SSLSocket as well as instantiation through * Direct instantiation of SSLSocket as well as instantiation through
ssl.wrap_socket are supported ssl.wrap_socket are supported
* Invocation of the function get_server_certificate with a value of * Invocation of the function get_server_certificate with a value of
PROTOCOL_DTLSv1 for the parameter ssl_version is supported PROTOCOL_DTLSv1 for the parameter ssl_version is supported
""" """
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from types import MethodType from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error from err import raise_as_ssl_module_error
def do_patch(): def do_patch():
import ssl as _ssl # import to be avoided if ssl module is never patched import ssl as _ssl # import to be avoided if ssl module is never patched
global _orig_SSLSocket_init, _orig_get_server_certificate global _orig_SSLSocket_init, _orig_get_server_certificate
global ssl global ssl
ssl = _ssl ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"): if hasattr(ssl, "PROTOCOL_DTLSv1"):
return return
_orig_wrap_socket = ssl.wrap_socket _orig_wrap_socket = ssl.wrap_socket
ssl.wrap_socket = _wrap_socket ssl.wrap_socket = _wrap_socket
ssl.PROTOCOL_DTLS = PROTOCOL_DTLS ssl.PROTOCOL_DTLS = PROTOCOL_DTLS
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2 ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS" ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2" ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
_orig_SSLSocket_init = ssl.SSLSocket.__init__ _orig_SSLSocket_init = ssl.SSLSocket.__init__
_orig_get_server_certificate = ssl.get_server_certificate _orig_get_server_certificate = ssl.get_server_certificate
ssl.SSLSocket.__init__ = _SSLSocket_init ssl.SSLSocket.__init__ = _SSLSocket_init
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error() raise_as_ssl_module_error()
def _wrap_socket(sock, keyfile=None, certfile=None, def _wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True,
ciphers=None,
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile, cb_user_config_ssl_ctx=None,
server_side=server_side, cert_reqs=cert_reqs, cb_user_config_ssl=None):
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
suppress_ragged_eofs=suppress_ragged_eofs, server_side=server_side, cert_reqs=cert_reqs,
ciphers=ciphers) ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): suppress_ragged_eofs=suppress_ragged_eofs,
"""Retrieve a server certificate ciphers=ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve a server certificate
Retrieve the certificate from the server at the specified address, Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string. and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it. If 'ca_certs' is specified, validate the server cert against it.
If 'ssl_version' is specified, use it in the connection attempt. If 'ssl_version' is specified, use it in the connection attempt.
""" """
if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2): if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2):
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if ca_certs is not None: if ca_certs is not None:
cert_reqs = ssl.CERT_REQUIRED cert_reqs = ssl.CERT_REQUIRED
else: else:
cert_reqs = ssl.CERT_NONE cert_reqs = ssl.CERT_NONE
af = getaddrinfo(addr[0], addr[1])[0][0] af = getaddrinfo(addr[0], addr[1])[0][0]
s = ssl.wrap_socket(socket(af, SOCK_DGRAM), s = ssl.wrap_socket(socket(af, SOCK_DGRAM),
ssl_version=ssl_version, ssl_version=ssl_version,
@ -109,13 +114,15 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
return ssl.DER_cert_to_PEM_cert(dercert) return ssl.DER_cert_to_PEM_cert(dercert)
def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None, def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None, server_hostname=None,
_context=None): _context=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
is_connection = is_datagram = False is_connection = is_datagram = False
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
is_connection = True is_connection = True
@ -167,7 +174,9 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side, cert_reqs, server_side, cert_reqs,
ssl_version, ca_certs, ssl_version, ca_certs,
do_handshake_on_connect, do_handshake_on_connect,
suppress_ragged_eofs, ciphers) suppress_ragged_eofs, ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
else: else:
self._connected = True self._connected = True
self._sslobj = sock self._sslobj = sock
@ -185,6 +194,8 @@ def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
self.do_handshake_on_connect = do_handshake_on_connect self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0 self._makefile_refs = 0
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._user_config_ssl = cb_user_config_ssl
# Perform method substitution and addition (without reference cycle) # Perform method substitution and addition (without reference cycle)
self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self))
@ -203,7 +214,9 @@ def _SSLSocket_listen(self, ignored):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
def _SSLSocket_accept(self): def _SSLSocket_accept(self):
if self._connected: if self._connected:
@ -218,7 +231,9 @@ def _SSLSocket_accept(self):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
return new_ssl_sock, addr return new_ssl_sock, addr
def _SSLSocket_real_connect(self, addr, return_errno): def _SSLSocket_real_connect(self, addr, return_errno):
@ -229,7 +244,9 @@ def _SSLSocket_real_connect(self, addr, return_errno):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
try: try:
self._sslobj.connect(addr) self._sslobj.connect(addr)
except socket_error as e: except socket_error as e:

View File

@ -174,12 +174,129 @@ class _CallbackProxy(object):
self.ssl_func = cbm.im_func self.ssl_func = cbm.im_func
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ssl_func(self.ssl_connection, *args, **kwargs) return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLConnection(object): class SSLContext(object):
"""DTLS peer association
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u'''
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
self._ctx.key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal = SSL_CTX_set_tmp_ecdh(self._ctx, self._ctx.key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object):
"""DTLS peer association
This class associates two DTLS peer instances, wrapping OpenSSL library This class associates two DTLS peer instances, wrapping OpenSSL library
state including SSL (struct ssl_st), SSL_CTX, and BIO instances. state including SSL (struct ssl_st), SSL_CTX, and BIO instances.
""" """
@ -210,6 +327,7 @@ class SSLConnection(object):
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method())) self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
@ -226,12 +344,13 @@ class SSLConnection(object):
self._pending_peer_address = None self._pending_peer_address = None
self._cb_keepalive = SSL_CTX_set_cookie_cb( self._cb_keepalive = SSL_CTX_set_cookie_cb(
self._ctx.value, self._ctx.value,
_CallbackProxy(self._generate_cookie_cb), _CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb)) _CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
if peer_address and self._do_handshake_on_connect: SSL_set_accept_state(self._ssl.value)
return lambda: self.do_handshake() if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _init_client(self, peer_address): def _init_client(self, peer_address):
if self._sock.type != socket.SOCK_DGRAM: if self._sock.type != socket.SOCK_DGRAM:
@ -245,15 +364,17 @@ class SSLConnection(object):
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method())) self._ctx = _CTX(SSL_CTX_new(client_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
else: else:
verify_mode = SSL_VERIFY_PEER verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode) self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_connect_state(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
if peer_address: SSL_set_connect_state(self._ssl.value)
return lambda: self.connect(peer_address) if peer_address:
return lambda: self.connect(peer_address)
def _config_ssl_ctx(self, verify_mode): def _config_ssl_ctx(self, verify_mode):
SSL_CTX_set_verify(self._ctx.value, verify_mode) SSL_CTX_set_verify(self._ctx.value, verify_mode)
@ -273,7 +394,8 @@ class SSLConnection(object):
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err: except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err) raise_ssl_error(ERR_NO_CIPHER, err)
SSL_CTX_set_info_callback(self._ctx.value, _ssl_logging_cb) if self._user_config_ssl_ctx:
self._user_config_ssl_ctx(self._intf_ssl_ctx)
def _copy_server(self): def _copy_server(self):
source = self._sock source = self._sock
@ -296,13 +418,16 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
new_source_rbio = new_source_wbio new_source_rbio = new_source_wbio
BIO_dgram_set_connected(self._wbio.value, BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address) source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value)) source._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(source._ssl.value) self._intf_ssl = SSL(source._ssl.value)
source._rbio = new_source_rbio SSL_set_accept_state(source._ssl.value)
source._wbio = new_source_wbio if self._user_config_ssl:
SSL_set_bio(source._ssl.value, self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio
source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value,
new_source_rbio.value, new_source_rbio.value,
new_source_wbio.value) new_source_wbio.value)
new_source_rbio.disown() new_source_rbio.disown()
@ -315,13 +440,16 @@ class SSLConnection(object):
self._rsock = source._rsock self._rsock = source._rsock
self._ctx = source._ctx self._ctx = source._ctx
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address) BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
if self._do_handshake_on_connect: SSL_set_accept_state(self._ssl.value)
return lambda: self.do_handshake() if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _check_nbio(self): def _check_nbio(self):
timeout = self._sock.gettimeout() timeout = self._sock.gettimeout()
if self._wbio_nb != timeout is not None: if self._wbio_nb != timeout is not None:
@ -378,10 +506,12 @@ class SSLConnection(object):
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True, ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
"""Constructor """Constructor
Arguments: Arguments:
these arguments match the ones of the SSLSocket class in the these arguments match the ones of the SSLSocket class in the
standard library's ssl module standard library's ssl module
""" """
@ -405,12 +535,17 @@ class SSLConnection(object):
self._do_handshake_on_connect = do_handshake_on_connect self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs self._suppress_ragged_eofs = suppress_ragged_eofs
self._ciphers = ciphers self._ciphers = ciphers
self._handshake_done = False self._handshake_done = False
self._wbio_nb = self._rbio_nb = False self._wbio_nb = self._rbio_nb = False
if isinstance(sock, SSLConnection): self._user_config_ssl_ctx = cb_user_config_ssl_ctx
post_init = self._copy_server() self._intf_ssl_ctx = None
elif isinstance(sock, _UnwrappedSocket): self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None
if isinstance(sock, SSLConnection):
post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket):
post_init = self._reconnect_unwrapped() post_init = self._reconnect_unwrapped()
else: else:
try: try:
@ -424,6 +559,9 @@ class SSLConnection(object):
SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU) SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl.value, 1500) DTLS_set_link_mtu(self._ssl.value, 1500)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown() self._rbio.disown()
self._wbio.disown() self._wbio.disown()
@ -535,10 +673,12 @@ class SSLConnection(object):
new_conn = SSLConnection(self, self._keyfile, self._certfile, True, new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version, self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect, self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers) self._suppress_ragged_eofs, self._ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
new_peer = self._pending_peer_address new_peer = self._pending_peer_address
self._pending_peer_address = None self._pending_peer_address = None
if self._do_handshake_on_connect: if self._do_handshake_on_connect:
# Note that since that connection's socket was just created in its # Note that since that connection's socket was just created in its
# constructor, the following operation must be blocking; hence # constructor, the following operation must be blocking; hence
# handshake-on-connect can only be used with a routing demux if # handshake-on-connect can only be used with a routing demux if