Minor fixes and "hopefully" compatible to Ubuntu 16.04

* dtls/__init__.py: Removed wrapper import
* dtls/openssl.py: Fixed line endings to LF
* dtls/patch.py: Removed PROTOCOL_SSLv3 import and fixed line endings to LF
* dtls/sslconnection.py: Fixed line endings to LF
* dtls/test/certs/*_ec.pem: Fixed line endings to LF
* dtls/test/echo_seq.py: Fixed line endings to LF
* dtls/test/simple_client.py: Fixed line endings to LF
* dtls/test/unit.py: Fixed line endings to LF
* dtls/test/unit_wrapper.py: Corrected wrapper import and fixed line endings to LF
* dtls/util.py: Fixed line endings to LF
* dtls/wrapper.py: Corrected function naming to wrap_client() and wrap_server(); Fixed line endings to LF
* dtls/x509.py: Fixed line endings to LF
incoming
mcfreis 2017-03-28 07:59:03 +02:00
parent dade3b8213
commit 083554e9e0
17 changed files with 2311 additions and 2295 deletions

View File

@ -1,3 +1,20 @@
2017-03-28 Björn Freise <mcfreis@gmx.net>
Minor fixes and "hopefully" compatible to Ubuntu 16.04
* dtls/__init__.py: Removed wrapper import
* dtls/openssl.py: Fixed line endings to LF
* dtls/patch.py: Removed PROTOCOL_SSLv3 import and fixed line endings to LF
* dtls/sslconnection.py: Fixed line endings to LF
* dtls/test/certs/*_ec.pem: Fixed line endings to LF
* dtls/test/echo_seq.py: Fixed line endings to LF
* dtls/test/simple_client.py: Fixed line endings to LF
* dtls/test/unit.py: Fixed line endings to LF
* dtls/test/unit_wrapper.py: Corrected wrapper import and fixed line endings to LF
* dtls/util.py: Fixed line endings to LF
* dtls/wrapper.py: Corrected function naming to wrap_client() and wrap_server(); Fixed line endings to LF
* dtls/x509.py: Fixed line endings to LF
2017-03-23 Björn Freise <mcfreis@gmx.net>
Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()

View File

@ -61,4 +61,3 @@ _prep_bins() # prepare before module imports
from patch import do_patch
from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux
from wrapper import DtlsSocket, client as wrap_client, server as wrap_server

File diff suppressed because it is too large Load Diff

View File

@ -36,7 +36,7 @@ has the following effects:
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from ssl import PROTOCOL_SSLv23, CERT_NONE
from types import MethodType
from weakref import proxy
import errno

View File

@ -45,30 +45,30 @@ import socket
import hmac
import datetime
from logging import getLogger
from os import urandom
from select import select
from weakref import proxy
from err import openssl_error, InvalidSocketError
from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS
from x509 import _X509, decode_cert
from tlock import tlock_init
from openssl import *
from util import _Rsrc, _BIO
_logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259
CERT_NONE = 0
CERT_OPTIONAL = 1
CERT_REQUIRED = 2
from os import urandom
from select import select
from weakref import proxy
from err import openssl_error, InvalidSocketError
from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS
from x509 import _X509, decode_cert
from tlock import tlock_init
from openssl import *
from util import _Rsrc, _BIO
_logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259
CERT_NONE = 0
CERT_OPTIONAL = 1
CERT_REQUIRED = 2
#
# One-time global OpenSSL library initialization
@ -83,64 +83,64 @@ DTLS_OPENSSL_VERSION_INFO = (
DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor
DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix
DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch
DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status
def _ssl_logging_cb(conn, where, return_code):
_state = where & ~SSL_ST_MASK
state = "SSL"
if _state & SSL_ST_INIT == SSL_ST_INIT:
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state += "_renew"
else:
state += "_init"
elif _state & SSL_ST_CONNECT:
state += "_connect"
elif _state & SSL_ST_ACCEPT:
state += "_accept"
elif _state == 0:
if where & SSL_CB_HANDSHAKE_START:
state += "_handshake_start"
elif where & SSL_CB_HANDSHAKE_DONE:
state += "_handshake_done"
if where & SSL_CB_LOOP:
state += '_loop'
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
elif where & SSL_CB_ALERT:
state += '_alert'
state += "_read" if where & SSL_CB_READ else "_write"
_logger.debug("%s:%s:%s" % (state,
SSL_alert_type_string_long(return_code),
SSL_alert_desc_string_long(return_code)))
elif where & SSL_CB_EXIT:
state += '_exit'
if return_code == 0:
_logger.debug("%s:%s:%d(failed)" % (state,
SSL_state_string_long(conn),
return_code))
elif return_code < 0:
_logger.debug("%s:%s:%d(error)" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
class _CTX(_Rsrc):
"""SSL_CTX wrapper"""
def __init__(self, value):
DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status
def _ssl_logging_cb(conn, where, return_code):
_state = where & ~SSL_ST_MASK
state = "SSL"
if _state & SSL_ST_INIT == SSL_ST_INIT:
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state += "_renew"
else:
state += "_init"
elif _state & SSL_ST_CONNECT:
state += "_connect"
elif _state & SSL_ST_ACCEPT:
state += "_accept"
elif _state == 0:
if where & SSL_CB_HANDSHAKE_START:
state += "_handshake_start"
elif where & SSL_CB_HANDSHAKE_DONE:
state += "_handshake_done"
if where & SSL_CB_LOOP:
state += '_loop'
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
elif where & SSL_CB_ALERT:
state += '_alert'
state += "_read" if where & SSL_CB_READ else "_write"
_logger.debug("%s:%s:%s" % (state,
SSL_alert_type_string_long(return_code),
SSL_alert_desc_string_long(return_code)))
elif where & SSL_CB_EXIT:
state += '_exit'
if return_code == 0:
_logger.debug("%s:%s:%d(failed)" % (state,
SSL_state_string_long(conn),
return_code))
elif return_code < 0:
_logger.debug("%s:%s:%d(error)" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
class _CTX(_Rsrc):
"""SSL_CTX wrapper"""
def __init__(self, value):
super(_CTX, self).__init__(value)
def __del__(self):
@ -174,130 +174,130 @@ class _CallbackProxy(object):
self.ssl_func = cbm.im_func
def __call__(self, *args, **kwargs):
return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLContext(object):
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object):
"""DTLS peer association
return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLContext(object):
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object):
"""DTLS peer association
This class associates two DTLS peer instances, wrapping OpenSSL library
state including SSL (struct ssl_st), SSL_CTX, and BIO instances.
"""
@ -319,19 +319,19 @@ class SSLConnection(object):
rsock = self._udp_demux.get_connection(None)
if rsock is self._sock:
self._rbio = self._wbio
else:
self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
server_method = DTLS_server_method
if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
else:
self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
server_method = DTLS_server_method
if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
elif self._cert_reqs == CERT_OPTIONAL:
verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE
else:
@ -345,37 +345,37 @@ class SSLConnection(object):
self._pending_peer_address = None
self._cb_keepalive = SSL_CTX_set_cookie_cb(
self._ctx.value,
_CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value)
if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake()
_CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value)
if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _init_client(self, peer_address):
if self._sock.type != socket.SOCK_DGRAM:
raise InvalidSocketError("sock must be of type SOCK_DGRAM")
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio
client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
else:
verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_connect_state(self._ssl.value)
if peer_address:
return lambda: self.connect(peer_address)
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio
client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
else:
verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_connect_state(self._ssl.value)
if peer_address:
return lambda: self.connect(peer_address)
def _config_ssl_ctx(self, verify_mode):
SSL_CTX_set_verify(self._ctx.value, verify_mode)
@ -392,14 +392,14 @@ class SSLConnection(object):
SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None)
if self._ciphers:
try:
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_config_ssl_ctx:
self._user_config_ssl_ctx(self._intf_ssl_ctx)
def _copy_server(self):
source = self._sock
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_config_ssl_ctx:
self._user_config_ssl_ctx(self._intf_ssl_ctx)
def _copy_server(self):
source = self._sock
self._udp_demux = source._udp_demux
rsock = self._udp_demux.get_connection(source._pending_peer_address)
self._ctx = source._ctx
@ -419,16 +419,16 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio
new_source_rbio = new_source_wbio
BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(source._ssl.value)
SSL_set_accept_state(source._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio
source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value,
BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(source._ssl.value)
SSL_set_accept_state(source._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio
source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value,
new_source_rbio.value,
new_source_wbio.value)
new_source_rbio.disown()
@ -441,16 +441,16 @@ class SSLConnection(object):
self._rsock = source._rsock
self._ctx = source._ctx
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect:
return lambda: self.do_handshake()
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _check_nbio(self):
timeout = self._sock.gettimeout()
if self._wbio_nb != timeout is not None:
@ -502,17 +502,17 @@ class SSLConnection(object):
def _verify_cookie_cb(self, ssl, cookie):
if self._get_cookie(ssl) != cookie:
raise Exception("DTLS cookie mismatch")
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
"""Constructor
Arguments:
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
"""Constructor
Arguments:
these arguments match the ones of the SSLSocket class in the
standard library's ssl module
"""
@ -528,46 +528,46 @@ class SSLConnection(object):
ciphers = "DEFAULT"
self._sock = sock
self._keyfile = keyfile
self._certfile = certfile
self._cert_reqs = cert_reqs
self._ssl_version = ssl_version
self._ca_certs = ca_certs
self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs
self._keyfile = keyfile
self._certfile = certfile
self._cert_reqs = cert_reqs
self._ssl_version = ssl_version
self._ca_certs = ca_certs
self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs
self._ciphers = ciphers
self._handshake_done = False
self._wbio_nb = self._rbio_nb = False
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._intf_ssl_ctx = None
self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None
if isinstance(sock, SSLConnection):
post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket):
self._handshake_done = False
self._wbio_nb = self._rbio_nb = False
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._intf_ssl_ctx = None
self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None
if isinstance(sock, SSLConnection):
post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket):
post_init = self._reconnect_unwrapped()
else:
try:
peer_address = sock.getpeername()
except socket.error:
peer_address = None
if server_side:
post_init = self._init_server(peer_address)
else:
post_init = self._init_client(peer_address)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
else:
SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl.value, 1500)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown()
self._wbio.disown()
if post_init:
post_init()
peer_address = None
if server_side:
post_init = self._init_server(peer_address)
else:
post_init = self._init_client(peer_address)
if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
else:
SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl.value, 1500)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown()
self._wbio.disown()
if post_init:
post_init()
def get_socket(self, inbound):
"""Retrieve a socket used by this connection
@ -633,25 +633,25 @@ class SSLConnection(object):
self._ssl.raw)
dtls_peer_address = DTLSv1_listen(self._ssl.value)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ:
# This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed")
return
elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Wrong version number; aborting handshake")
raise
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
_logger.debug("Mismatching cookie received; aborting handshake")
raise
elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER:
_logger.debug("No shared cipher; aborting handshake")
raise
_logger.exception("Unexpected error in DTLSv1_listen")
raise
finally:
self._listening = False
self._listening_peer_address = None
if type(peer_address) is tuple:
if err.ssl_error == SSL_ERROR_WANT_READ:
# This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed")
return
elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Wrong version number; aborting handshake")
raise
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
_logger.debug("Mismatching cookie received; aborting handshake")
raise
elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER:
_logger.debug("No shared cipher; aborting handshake")
raise
_logger.exception("Unexpected error in DTLSv1_listen")
raise
finally:
self._listening = False
self._listening_peer_address = None
if type(peer_address) is tuple:
_logger.debug("New local peer: %s", dtls_peer_address)
self._pending_peer_address = peer_address
else:
@ -672,17 +672,17 @@ class SSLConnection(object):
if not self._pending_peer_address:
if not self.listen():
_logger.debug("Accept returning without connection")
return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
new_peer = self._pending_peer_address
self._pending_peer_address = None
if self._do_handshake_on_connect:
_logger.debug("Accept returning without connection")
return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
new_peer = self._pending_peer_address
self._pending_peer_address = None
if self._do_handshake_on_connect:
# Note that since that connection's socket was just created in its
# constructor, the following operation must be blocking; hence
# handshake-on-connect can only be used with a routing demux if
@ -734,48 +734,48 @@ class SSLConnection(object):
Read up to len bytes and return them.
Arguments:
len -- maximum number of bytes to read
Return value:
string containing read bytes
"""
try:
return self._wrap_socket_library_call(
lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise
def write(self, data):
"""Write data to connection
Write data as string of bytes.
len -- maximum number of bytes to read
Return value:
string containing read bytes
"""
try:
return self._wrap_socket_library_call(
lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise
def write(self, data):
"""Write data to connection
Write data as string of bytes.
Arguments:
data -- buffer containing data to be written
Return value:
number of bytes actually transmitted
"""
try:
ret = self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise
if ret:
self._handshake_done = True
return ret
def shutdown(self):
"""Shut down the DTLS connection
This method attemps to complete a bidirectional shutdown between
peers. For non-blocking sockets, it should be called repeatedly until
data -- buffer containing data to be written
Return value:
number of bytes actually transmitted
"""
try:
ret = self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise
if ret:
self._handshake_done = True
return ret
def shutdown(self):
"""Shut down the DTLS connection
This method attemps to complete a bidirectional shutdown between
peers. For non-blocking sockets, it should be called repeatedly until
it no longer raises continuation request exceptions.
"""
@ -826,33 +826,33 @@ class SSLConnection(object):
return
if binary_form:
return i2d_X509(peer_cert.value)
if self._cert_reqs == CERT_NONE:
return {}
return decode_cert(peer_cert)
peer_certificate = getpeercert # compatibility with _ssl call interface
def getpeercertchain(self, binary_form=False):
try:
stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value)
except openssl_error():
return
peer_cert_chain = [_Rsrc(cert) for cert in certs]
ret = []
if binary_form:
ret = [i2d_X509(x.value) for x in peer_cert_chain]
elif len(peer_cert_chain):
ret = [decode_cert(x) for x in peer_cert_chain]
return ret
def cipher(self):
"""Retrieve information about the current cipher
Return a triple consisting of cipher name, SSL protocol version defining
its use, and the number of secret bits. Return None if handshaking
return i2d_X509(peer_cert.value)
if self._cert_reqs == CERT_NONE:
return {}
return decode_cert(peer_cert)
peer_certificate = getpeercert # compatibility with _ssl call interface
def getpeercertchain(self, binary_form=False):
try:
stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value)
except openssl_error():
return
peer_cert_chain = [_Rsrc(cert) for cert in certs]
ret = []
if binary_form:
ret = [i2d_X509(x.value) for x in peer_cert_chain]
elif len(peer_cert_chain):
ret = [decode_cert(x) for x in peer_cert_chain]
return ret
def cipher(self):
"""Retrieve information about the current cipher
Return a triple consisting of cipher name, SSL protocol version defining
its use, and the number of secret bits. Return None if handshaking
has not been completed.
"""

View File

@ -1,11 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBgzCCASoCCQDdMwvUA/R3lzAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU3WhcNMjcwMzA1MDgzNjU3WjBKMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENB
IEluYzERMA8GA1UEAwwIUmF5Q0FJbmMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AASD4xiQkPryjEwUl/GYeGu1CSA3UC6BUY3TiGED3zrC5Bn/POaVVn9GGOQMZUFi
rCkuTgfg/qeIzTrTFndiR5C/MAoGCCqGSM49BAMDA0cAMEQCIHpd9qMvZZV6iaB5
HrmlyfmhIuLBxDQra20Uxl2Y8N64AiAmPKqwPPp7z6IT2AzAXyHCPoVxwWA0NfGx
nmXoYpDFlw==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBgzCCASoCCQDdMwvUA/R3lzAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU3WhcNMjcwMzA1MDgzNjU3WjBKMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENB
IEluYzERMA8GA1UEAwwIUmF5Q0FJbmMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AASD4xiQkPryjEwUl/GYeGu1CSA3UC6BUY3TiGED3zrC5Bn/POaVVn9GGOQMZUFi
rCkuTgfg/qeIzTrTFndiR5C/MAoGCCqGSM49BAMDA0cAMEQCIHpd9qMvZZV6iaB5
HrmlyfmhIuLBxDQra20Uxl2Y8N64AiAmPKqwPPp7z6IT2AzAXyHCPoVxwWA0NfGx
nmXoYpDFlw==
-----END CERTIFICATE-----

View File

@ -1,19 +1,19 @@
-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEMWCku4TqKwrQdeECm5LQPCBnr7+cqE4InlRYeObLOxoAoGCCqGSM49
AwEHoUQDQgAEgroFe2fym1V7E3zr/zjuJixpyAjwfig+UTsxxm/04IvXzk2jQCQC
TgbDVohJ8dgh4iEENZv2axWye7XCBzbftQ==
-----END EC PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----
-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEMWCku4TqKwrQdeECm5LQPCBnr7+cqE4InlRYeObLOxoAoGCCqGSM49
AwEHoUQDQgAEgroFe2fym1V7E3zr/zjuJixpyAjwfig+UTsxxm/04IvXzk2jQCQC
TgbDVohJ8dgh4iEENZv2axWye7XCBzbftQ==
-----END EC PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----

View File

@ -1,11 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----

View File

@ -36,13 +36,13 @@ import socket
from os import path
from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code
from dtls.sslconnection import SSLConnection
from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN
def main():
sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sck.bind(("127.0.0.1", 28000))
from dtls.sslconnection import SSLConnection
from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN
def main():
sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sck.bind(("127.0.0.1", 28000))
sck.settimeout(30)
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
scn = SSLConnection(

View File

@ -1,7 +1,7 @@
HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
distinguished_name = req_distinguished_name
prompt = no

View File

@ -1,7 +1,7 @@
HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
distinguished_name = req_distinguished_name
prompt = no

View File

@ -1,15 +1,15 @@
from os import path
import ssl
from socket import socket, AF_INET, SOCK_DGRAM, SHUT_RDWR
from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code
from dtls import do_patch
do_patch()
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
sock = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), cert_reqs=ssl.CERT_REQUIRED, ca_certs=path.join(cert_path, "ca-cert.pem"))
sock.connect(('localhost', 28000))
sock.send('Hi there')
print sock.recv()
sock.unwrap()
sock.shutdown(SHUT_RDWR)
from os import path
import ssl
from socket import socket, AF_INET, SOCK_DGRAM, SHUT_RDWR
from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code
from dtls import do_patch
do_patch()
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
sock = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), cert_reqs=ssl.CERT_REQUIRED, ca_certs=path.join(cert_path, "ca-cert.pem"))
sock.connect(('localhost', 28000))
sock.send('Hi there')
print sock.recv()
sock.unwrap()
sock.shutdown(SHUT_RDWR)

View File

@ -78,12 +78,12 @@ class BasicSocketTests(unittest.TestCase):
def test_constants(self):
ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_DTLSv1 # added
ssl.PROTOCOL_DTLSv1_2 # added
ssl.PROTOCOL_DTLS # added
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_DTLSv1 # added
ssl.PROTOCOL_DTLSv1_2 # added
ssl.PROTOCOL_DTLS # added
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
def test_dtls_openssl_version(self):
@ -91,22 +91,22 @@ class BasicSocketTests(unittest.TestCase):
t = ssl.DTLS_OPENSSL_VERSION_INFO
s = ssl.DTLS_OPENSSL_VERSION
self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str)
# Some sanity checks follow
# >= 1.0.2
self.assertGreaterEqual(n, 0x10002000)
# < 2.0
self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t
self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str)
# Some sanity checks follow
# >= 1.0.2
self.assertGreaterEqual(n, 0x10002000)
# < 2.0
self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 1)
self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26)
self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change
@ -299,37 +299,37 @@ class NetworkedTests(unittest.TestCase):
s.close()
if test_support.verbose:
sys.stdout.write(("\nNeeded %d calls to do_handshake() " +
"to establish session.\n") % count)
def test_get_server_certificate(self):
for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
with test_support.transient_internet() as remote:
pem = ssl.get_server_certificate(remote,
prot)
if not pem:
self.fail("No server certificate!")
try:
pem = ssl.get_server_certificate(remote,
prot,
ca_certs=OTHER_CERTFILE)
except ssl.SSLError:
# should fail
pass
else:
self.fail("Got server certificate %s!" % pem)
pem = ssl.get_server_certificate(remote,
prot,
ca_certs=ISSUER_CERTFILE)
if not pem:
self.fail("No server certificate!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate is\n%s\n" % pem)
class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread):
"to establish session.\n") % count)
def test_get_server_certificate(self):
for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
with test_support.transient_internet() as remote:
pem = ssl.get_server_certificate(remote,
prot)
if not pem:
self.fail("No server certificate!")
try:
pem = ssl.get_server_certificate(remote,
prot,
ca_certs=OTHER_CERTFILE)
except ssl.SSLError:
# should fail
pass
else:
self.fail("Got server certificate %s!" % pem)
pem = ssl.get_server_certificate(remote,
prot,
ca_certs=ISSUER_CERTFILE)
if not pem:
self.fail("No server certificate!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate is\n%s\n" % pem)
class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread):
"""A mildly complicated class, because we want it to work both
with and without the SSL wrapper around the socket connection, so
@ -532,20 +532,20 @@ class ThreadedEchoServer(threading.Thread):
if acc_ret:
newconn, connaddr = acc_ret
if test_support.verbose and self.chatty:
sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn)
handler.start()
except socket.timeout:
pass
except ssl.SSLError:
pass
except KeyboardInterrupt:
self.stop()
self.sock.close()
def register_handler(self, add):
with self.num_handlers_lock:
sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn)
handler.start()
except socket.timeout:
pass
except ssl.SSLError:
pass
except KeyboardInterrupt:
self.stop()
self.sock.close()
def register_handler(self, add):
with self.num_handlers_lock:
if add:
self.num_handlers += 1
else:
@ -1042,40 +1042,40 @@ class ThreadedTests(unittest.TestCase):
"certs", "badkey.pem"))
def test_protocol_dtlsv1(self):
"""Connecting to a DTLSv1 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
# server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
# server: any - client: 1.0 and 1.2(any) -> ok
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
"""Connecting to a DTLSv1 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
# server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
# server: any - client: 1.0 and 1.2(any) -> ok
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS",
"msg 5", "msg 6")
@ -1088,13 +1088,13 @@ class ThreadedTests(unittest.TestCase):
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
wrapped = False
try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port))
s = s.unwrap()
if test_support.verbose:
# try to connect
wrapped = False
try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port))
s = s.unwrap()
if test_support.verbose:
sys.stdout.write("\n")
for indata in msgs:
if test_support.verbose:

File diff suppressed because it is too large Load Diff

View File

@ -54,18 +54,18 @@ class _BIO(_Rsrc):
if self.owned:
_logger.debug("Freeing BIO: %d", self.raw)
from openssl import BIO_free
BIO_free(self._value)
self.owned = False
self._value = None
class _EC_KEY(_Rsrc):
"""EC KEY wrapper"""
def __init__(self, value):
super(_EC_KEY, self).__init__(value)
def __del__(self):
_logger.debug("Freeing EC_KEY: %d", self.raw)
from openssl import EC_KEY_free
EC_KEY_free(self._value)
self._value = None
BIO_free(self._value)
self.owned = False
self._value = None
class _EC_KEY(_Rsrc):
"""EC KEY wrapper"""
def __init__(self, value):
super(_EC_KEY, self).__init__(value)
def __del__(self):
_logger.debug("Freeing EC_KEY: %d", self.raw)
from openssl import EC_KEY_free
EC_KEY_free(self._value)
self._value = None

View File

@ -1,370 +1,370 @@
# -*- encoding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# The License is also distributed with this work in the file named "LICENSE."
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DTLS Socket
This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS
protocol.
Classes:
DtlsSocket -- DTLS Socket wrapper for use as a client or server
"""
import select
from logging import getLogger
import ssl
import socket
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
import err as err_codes
_logger = getLogger(__name__)
def client(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
def server(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=False, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
class DtlsSocket(object):
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
def __init__(self,
sock=None,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None,
server_key_exchange_curve=None,
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None:
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options
# Default socket creation
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(sock, socket.socket):
_sock = sock
self._sock = ssl.wrap_socket(_sock,
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=self._ciphers,
cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
cb_user_config_ssl=self.user_config_ssl)
if self._server_side:
self._clients = {}
self._timeout = None
def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item)
raise AttributeError
def user_config_ssl_ctx(self, _ctx):
"""
:param SSLContext _ctx:
"""
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain(flags=self._server_cert_options)
_ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)
def user_config_ssl(self, _ssl):
"""
:param SSL _ssl:
"""
if self._user_mtu:
_ssl.set_link_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
try:
self._sock.unwrap()
except:
pass
self._sock.close()
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout:
# __Nothing__ received from any client
raise socket.timeout
try:
for conn in r:
_last_peer = conn.getpeername() if conn._connected else None
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf:
if conn in self._clients:
return buf, self._clients[conn].getAddr()
else:
_logger.debug('Received data from an already disconnected client!')
except Exception as e:
setattr(e, 'peer', _last_peer)
raise e
try:
for conn in self._getClientReadingSockets():
if conn.get_timeout():
ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
except Exception as e:
raise e
# __No_data__ received from any client
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e:
if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e
else:
if buf:
return buf, self._sock.getpeername()
# __No_data__ received from any client
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
try:
if not self._sock._connected:
self._sock.connect(address)
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e:
raise e
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e:
raise e
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
_logger.debug('Client already connected %s' % str(client))
raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e:
if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
self._clientDrop(conn, error=e)
raise e
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
self._clientDrop(conn, error=e)
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
_data = data
if False:
_data = data.raw
ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e:
raise e
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
try:
conn.unwrap()
except:
pass
conn.close()
except Exception as e:
pass
# -*- coding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# The License is also distributed with this work in the file named "LICENSE."
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DTLS Socket
This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS
protocol.
Classes:
DtlsSocket -- DTLS Socket wrapper for use as a client or server
"""
import select
from logging import getLogger
import ssl
import socket
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
import err as err_codes
_logger = getLogger(__name__)
def wrap_client(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
def wrap_server(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=False, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
class DtlsSocket(object):
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
def __init__(self,
sock=None,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None,
server_key_exchange_curve=None,
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None:
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options
# Default socket creation
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(sock, socket.socket):
_sock = sock
self._sock = ssl.wrap_socket(_sock,
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=self._ciphers,
cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
cb_user_config_ssl=self.user_config_ssl)
if self._server_side:
self._clients = {}
self._timeout = None
def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item)
raise AttributeError
def user_config_ssl_ctx(self, _ctx):
"""
:param SSLContext _ctx:
"""
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain(flags=self._server_cert_options)
_ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)
def user_config_ssl(self, _ssl):
"""
:param SSL _ssl:
"""
if self._user_mtu:
_ssl.set_link_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
try:
self._sock.unwrap()
except:
pass
self._sock.close()
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout:
# __Nothing__ received from any client
raise socket.timeout
try:
for conn in r:
_last_peer = conn.getpeername() if conn._connected else None
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf:
if conn in self._clients:
return buf, self._clients[conn].getAddr()
else:
_logger.debug('Received data from an already disconnected client!')
except Exception as e:
setattr(e, 'peer', _last_peer)
raise e
try:
for conn in self._getClientReadingSockets():
if conn.get_timeout():
ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
except Exception as e:
raise e
# __No_data__ received from any client
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e:
if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e
else:
if buf:
return buf, self._sock.getpeername()
# __No_data__ received from any client
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
try:
if not self._sock._connected:
self._sock.connect(address)
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e:
raise e
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e:
raise e
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
_logger.debug('Client already connected %s' % str(client))
raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e:
if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
self._clientDrop(conn, error=e)
raise e
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
self._clientDrop(conn, error=e)
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
_data = data
if False:
_data = data.raw
ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e:
raise e
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
try:
conn.unwrap()
except:
pass
conn.close()
except Exception as e:
pass

View File

@ -40,23 +40,23 @@ _logger = getLogger(__name__)
class _X509(_Rsrc):
"""Wrapper for the cryptographic library's X509 resource"""
def __init__(self, value):
super(_X509, self).__init__(value)
def __del__(self):
_logger.debug("Freeing X509: %d", self.raw)
X509_free(self._value)
self._value = None
super(_X509, self).__init__(value)
def __del__(self):
_logger.debug("Freeing X509: %d", self.raw)
X509_free(self._value)
self._value = None
class _STACK(_Rsrc):
"""Wrapper for the cryptographic library's stacks"""
def __init__(self, value):
super(_STACK, self).__init__(value)
def __del__(self):
_logger.debug("Freeing stack: %d", self.raw)
sk_pop_free(self._value)
self._value = None
super(_STACK, self).__init__(value)
def __del__(self):
_logger.debug("Freeing stack: %d", self.raw)
sk_pop_free(self._value)
self._value = None
def decode_cert(cert):
"""Convert an X509 certificate into a Python dictionary