Revert "Added supoort for EC-Functions and DTLSv1.2 (Updated OpenSSL-Lib to V1.0.2j)"

This reverts commit 5776e38445.
incoming
mcfreis 2017-03-21 08:07:53 +01:00
parent 5776e38445
commit 84d2906a5e
6 changed files with 28 additions and 820 deletions

View File

@ -84,7 +84,6 @@ BIO_NOCLOSE = 0x00
BIO_CLOSE = 0x01 BIO_CLOSE = 0x01
SSLEAY_VERSION = 0 SSLEAY_VERSION = 0
SSL_OP_NO_COMPRESSION = 0x00020000 SSL_OP_NO_COMPRESSION = 0x00020000
SSL_OP_NO_QUERY_MTU = 0x00001000
SSL_VERIFY_NONE = 0x00 SSL_VERIFY_NONE = 0x00
SSL_VERIFY_PEER = 0x01 SSL_VERIFY_PEER = 0x01
SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02
@ -98,58 +97,17 @@ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100
SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200 SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200
SSL_SESS_CACHE_NO_INTERNAL = \ SSL_SESS_CACHE_NO_INTERNAL = \
SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE
SSL_BUILD_CHAIN_FLAG_UNTRUSTED = 0x1
SSL_BUILD_CHAIN_FLAG_NO_ROOT = 0x2
SSL_BUILD_CHAIN_FLAG_CHECK = 0x4
SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = 0x8
SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = 0x10
SSL_FILE_TYPE_PEM = 1 SSL_FILE_TYPE_PEM = 1
GEN_DIRNAME = 4 GEN_DIRNAME = 4
NID_subject_alt_name = 85 NID_subject_alt_name = 85
CRYPTO_LOCK = 1 CRYPTO_LOCK = 1
SSL_ST_MASK = 0x0FFF
SSL_ST_CONNECT = 0x1000
SSL_ST_ACCEPT = 0x2000
SSL_ST_INIT = (SSL_ST_CONNECT | SSL_ST_ACCEPT)
SSL_ST_BEFORE = 0x4000
SSL_ST_OK = 0x03
SSL_ST_RENEGOTIATE = (0x04 | SSL_ST_INIT)
SSL_ST_ERR = 0x05
SSL_CB_LOOP = 0x01
SSL_CB_EXIT = 0x02
SSL_CB_READ = 0x04
SSL_CB_WRITE = 0x08
SSL_CB_ALERT = 0x4000
SSL_CB_READ_ALERT = (SSL_CB_ALERT | SSL_CB_READ)
SSL_CB_WRITE_ALERT = (SSL_CB_ALERT | SSL_CB_WRITE)
SSL_CB_ACCEPT_LOOP = (SSL_ST_ACCEPT | SSL_CB_LOOP)
SSL_CB_ACCEPT_EXIT = (SSL_ST_ACCEPT | SSL_CB_EXIT)
SSL_CB_CONNECT_LOOP = (SSL_ST_CONNECT | SSL_CB_LOOP)
SSL_CB_CONNECT_EXIT = (SSL_ST_CONNECT | SSL_CB_EXIT)
SSL_CB_HANDSHAKE_START = 0x10
SSL_CB_HANDSHAKE_DONE = 0x20
# #
# Integer constants - internal # Integer constants - internal
# #
SSL_CTRL_SET_SESS_CACHE_MODE = 44 SSL_CTRL_SET_SESS_CACHE_MODE = 44
SSL_CTRL_SET_READ_AHEAD = 41 SSL_CTRL_SET_READ_AHEAD = 41
SSL_CTRL_OPTIONS = 32 SSL_CTRL_OPTIONS = 32
SSL_CTRL_CLEAR_OPTIONS = 77
SSL_CTRL_SET_ECDH_AUTO = 94
SSL_CTRL_BUILD_CERT_CHAIN = 105
SSL_CTRL_SET_MTU = 17
SSL_CTRL_GET_CURVES = 90
SSL_CTRL_SET_CURVES = 91
SSL_CTRL_SET_CURVES_LIST = 92
SSL_CTRL_GET_SHARED_CURVE = 93
SSL_CTRL_SET_SIGALGS = 97
SSL_CTRL_SET_SIGALGS_LIST = 98
SSL_CTRL_SET_CLIENT_SIGALGS = 101
SSL_CTRL_SET_CLIENT_SIGALGS_LIST = 102
SSL_CTRL_SET_TMP_ECDH = 4
BIO_CTRL_INFO = 3 BIO_CTRL_INFO = 3
BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_SET_CONNECTED = 32
BIO_CTRL_DGRAM_GET_PEER = 46 BIO_CTRL_DGRAM_GET_PEER = 46
@ -161,53 +119,6 @@ DTLS_CTRL_LISTEN = 75
X509_NAME_MAXLEN = 256 X509_NAME_MAXLEN = 256
GETS_MAXLEN = 2048 GETS_MAXLEN = 2048
class _EllipticCurve(object):
_curves = None
@classmethod
def _load_elliptic_curves(cls, lib):
num_curves = lib.EC_get_builtin_curves(0, 0)
if num_curves > 0:
builtin_curves = create_string_buffer(sizeof(EC_builtin_curve) * num_curves)
lib.EC_get_builtin_curves(builtin_curves, num_curves)
return set(cls.from_nid(lib, c.nid) for c in cast(builtin_curves, POINTER(EC_builtin_curve))[:num_curves])
return set()
@classmethod
def _get_elliptic_curves(cls, lib):
if cls._curves is None:
cls._curves = cls._load_elliptic_curves(lib)
return cls._curves
@classmethod
def from_nid(cls, lib, nid):
return cls(lib, nid, cast(lib.OBJ_nid2sn(nid), c_char_p).value.decode("ascii"))
def __init__(self, lib, nid, name):
self._lib = lib
self._nid = nid
self.name = name
def __repr__(self):
return "<Curve %r>" % (self.name,)
def as_EC_KEY(self):
key = self._lib.EC_KEY_new_by_curve_name(self._nid)
return key
def get_elliptic_curves():
return _EllipticCurve._get_elliptic_curves(libcrypto)
def get_elliptic_curve(name):
for curve in get_elliptic_curves():
if curve.name == name:
return curve
raise ValueError("unknown curve name", name)
# #
# Parameter data types # Parameter data types
# #
@ -260,11 +171,6 @@ class SSL(FuncParam):
super(SSL, self).__init__(value) super(SSL, self).__init__(value)
class EC_builtin_curve(Structure):
_fields_ = [("nid", c_int),
("comment", c_char_p)]
class BIO(FuncParam): class BIO(FuncParam):
def __init__(self, value): def __init__(self, value):
super(BIO, self).__init__(value) super(BIO, self).__init__(value)
@ -566,22 +472,14 @@ _subst = {c_long_parm: c_long}
_sigs = {} _sigs = {}
__all__ = ["BIO_NOCLOSE", "BIO_CLOSE", __all__ = ["BIO_NOCLOSE", "BIO_CLOSE",
"SSLEAY_VERSION", "SSLEAY_VERSION",
"SSL_OP_NO_COMPRESSION", "SSL_OP_NO_QUERY_MTU", "SSL_OP_NO_COMPRESSION",
"SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER",
"SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE",
"SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT",
"SSL_SESS_CACHE_SERVER", "SSL_SESS_CACHE_BOTH", "SSL_SESS_CACHE_SERVER", "SSL_SESS_CACHE_BOTH",
"SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP", "SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP",
"SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL",
"SSL_ST_MASK", "SSL_ST_CONNECT", "SSL_ST_ACCEPT", "SSL_ST_INIT", "SSL_ST_BEFORE", "SSL_ST_OK",
"SSL_ST_RENEGOTIATE", "SSL_ST_ERR", "SSL_CB_LOOP", "SSL_CB_EXIT", "SSL_CB_READ", "SSL_CB_WRITE",
"SSL_CB_ALERT", "SSL_CB_READ_ALERT", "SSL_CB_WRITE_ALERT",
"SSL_CB_ACCEPT_LOOP", "SSL_CB_ACCEPT_EXIT",
"SSL_CB_CONNECT_LOOP", "SSL_CB_CONNECT_EXIT",
"SSL_CB_HANDSHAKE_START", "SSL_CB_HANDSHAKE_DONE",
"SSL_FILE_TYPE_PEM", "SSL_FILE_TYPE_PEM",
"SSL_BUILD_CHAIN_FLAG_NO_ROOT",
"SSL_state_string_long",
"GEN_DIRNAME", "NID_subject_alt_name", "GEN_DIRNAME", "NID_subject_alt_name",
"CRYPTO_LOCK", "CRYPTO_LOCK",
"CRYPTO_set_locking_callback", "CRYPTO_set_locking_callback",
@ -591,28 +489,15 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE",
"BIO_dgram_set_connected", "BIO_dgram_set_connected",
"BIO_dgram_get_peer", "BIO_dgram_set_peer", "BIO_dgram_get_peer", "BIO_dgram_set_peer",
"BIO_set_nbio", "BIO_set_nbio",
"SSL_CTX_set_ecdh_auto",
"SSL_CTX_build_cert_chain",
"SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead", "SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead",
"SSL_CTX_set_options", "SSL_CTX_set_options",
"SSL_set_options", "SSL_clear_options", "SSL_get_options",
"SSL_get1_curves", "SSL_CTX_set1_curves", "SSL_CTX_set1_curves_list", "SSL_set1_curves_list",
"SSL_CTX_set_tmp_ecdh",
"SSL_CTX_set_info_callback",
"SSL_set_mtu",
"SSL_read", "SSL_write", "SSL_read", "SSL_write",
"SSL_CTX_set_cookie_cb", "SSL_CTX_set_cookie_cb",
"SSL_set1_client_sigalgs_list", "SSL_set1_client_sigalgs",
"SSL_CTX_set1_client_sigalgs_list", "SSL_CTX_set1_client_sigalgs",
"SSL_set1_sigalgs_list", "SSL_set1_sigalgs",
"SSL_CTX_set1_sigalgs_list", "SSL_CTX_set1_sigalgs",
"OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print",
"X509_get_notAfter", "X509_get_notAfter",
"ASN1_item_d2i", "GENERAL_NAME_print", "ASN1_item_d2i", "GENERAL_NAME_print",
"EC_KEY_free",
"sk_value", "sk_value",
"sk_pop_free", "sk_pop_free",
"get_elliptic_curves",
"i2d_X509"] # note: the following map adds to this list "i2d_X509"] # note: the following map adds to this list
map(lambda x: _make_function(*x), ( map(lambda x: _make_function(*x), (
@ -626,9 +511,6 @@ map(lambda x: _make_function(*x), (
("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)), ("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)),
("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_2_client_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_2_server_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLS_server_method", libssl, ((DTLSv1Method, "ret"),)),
("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))),
("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))), ("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))),
("SSL_CTX_set_cookie_generate_cb", libssl, ("SSL_CTX_set_cookie_generate_cb", libssl,
@ -681,8 +563,6 @@ map(lambda x: _make_function(*x), (
("SSL_CTX_set_verify", libssl, ("SSL_CTX_set_verify", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_int, "mode"), ((None, "ret"), (SSLCTX, "ctx"), (c_int, "mode"),
(c_void_p, "verify_callback", 1, None))), (c_void_p, "verify_callback", 1, None))),
("SSL_CTX_set_verify_depth", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_int, "depth"))),
("SSL_accept", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_accept", libssl, ((c_int, "ret"), (SSL, "ssl"))),
("SSL_connect", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_connect", libssl, ((c_int, "ret"), (SSL, "ssl"))),
("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))), ("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))),
@ -750,22 +630,6 @@ map(lambda x: _make_function(*x), (
("SSL_CIPHER_get_bits", libssl, ("SSL_CIPHER_get_bits", libssl,
((c_int, "ret"), (SSL_CIPHER, "cipher"), ((c_int, "ret"), (SSL_CIPHER, "cipher"),
(POINTER(c_int), "alg_bits", 1, None)), True, None), (POINTER(c_int), "alg_bits", 1, None)), True, None),
("EC_get_builtin_curves", libcrypto,
((c_int, "ret"), (POINTER(EC_builtin_curve), "r"), (c_int, "nitems"))),
("EC_KEY_new_by_curve_name", libcrypto,
((POINTER(c_char), "ret"), (c_int, "nid"))),
("EC_KEY_free", libcrypto,
((None, "ret"), (POINTER(c_char), "key")), False),
("OBJ_nid2sn", libcrypto,
((c_char_p, "ret"), (c_int, "n"))),
("EC_curve_nist2nid", libcrypto,
((c_int, "ret"), (POINTER(c_char), "name")), True, None),
("EC_curve_nid2nist", libcrypto,
((c_char_p, "ret"), (c_int, "nid")), True, None),
("SSL_CTX_set_info_callback", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_void_p, "callback")), False),
("SSL_state_string_long", libssl,
((c_char_p, "ret"), (SSL, "ssl")), False),
)) ))
# #
@ -784,15 +648,6 @@ def CRYPTO_set_locking_callback(locking_function):
_locking_cb = _rvoid_int_int_charp_int(py_locking_function) _locking_cb = _rvoid_int_int_charp_int(py_locking_function)
_CRYPTO_set_locking_callback(_locking_cb) _CRYPTO_set_locking_callback(_locking_cb)
def SSL_set_mtu(ssl, mtu):
return _SSL_ctrl(ssl, SSL_CTRL_SET_MTU, mtu, None)
def SSL_CTX_build_cert_chain(ctx, flags):
return _SSL_CTX_ctrl(ctx, SSL_CTRL_BUILD_CERT_CHAIN, flags, None)
def SSL_CTX_set_ecdh_auto(ctx, onoff):
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, None)
def SSL_CTX_set_session_cache_mode(ctx, mode): def SSL_CTX_set_session_cache_mode(ctx, mode):
# Returns the previous value of mode # Returns the previous value of mode
_SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SESS_CACHE_MODE, mode, None) _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SESS_CACHE_MODE, mode, None)
@ -803,47 +658,7 @@ def SSL_CTX_set_read_ahead(ctx, m):
def SSL_CTX_set_options(ctx, options): def SSL_CTX_set_options(ctx, options):
# Returns the new option bitmaks after adding the given options # Returns the new option bitmaks after adding the given options
return _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None)
def SSL_set_options(ssl, op):
return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, op, None)
def SSL_clear_options(ssl, op):
return _SSL_ctrl(ssl, SSL_CTRL_CLEAR_OPTIONS, op, None)
def SSL_get_options(ssl):
return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, 0, None)
def SSL_get1_curves(ssl, s):
_s = cast(s, POINTER(c_int))
if s:
mem = None
cnt = SSL_get1_curves(ssl, 0)
if cnt >= s:
mem = create_string_buffer(sizeof(POINTER(c_int)) * s)
ret = _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, s, mem)
return [x for x in cast(mem, POINTER(c_int))[:s]]
else:
return _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, 0, _s)
def SSL_CTX_set1_curves(ctx, clist, clistlen):
_curves = (c_int * len(clist))(*clist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES, len(_curves), _curves)
def SSL_get_shared_curve(ssl, n):
return _SSL_ctrl(ssl, SSL_CTRL_GET_SHARED_CURVE, n, 0)
def SSL_CTX_set1_curves_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES_LIST, 0, _s)
def SSL_set1_curves_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_CURVES_LIST, 0, _s)
def SSL_CTX_set_tmp_ecdh(ctx, ecdh):
# return 1 on success and 0 on failure
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, ecdh)
_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), _rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte),
POINTER(c_uint)) POINTER(c_uint))
@ -930,8 +745,6 @@ def SSL_write(ssl, data):
str_data = data str_data = data
elif hasattr(data, "tobytes") and callable(data.tobytes): elif hasattr(data, "tobytes") and callable(data.tobytes):
str_data = data.tobytes() str_data = data.tobytes()
elif isinstance(data, ctypes.Array):
str_data = data.raw
else: else:
str_data = str(data) str_data = str(data)
return _SSL_write(ssl, str_data, len(str_data)) return _SSL_write(ssl, str_data, len(str_data))
@ -1014,73 +827,3 @@ def i2d_X509(x509):
bio = _BIO(BIO_new(BIO_s_mem())) bio = _BIO(BIO_new(BIO_s_mem()))
_i2d_X509_bio(bio.value, x509) _i2d_X509_bio(bio.value, x509)
return BIO_get_mem_data(bio.value) return BIO_get_mem_data(bio.value)
def EC_KEY_free(key):
_EC_KEY_free(cast(key, POINTER(c_char)))
_rvoid_voidp_int_int = CFUNCTYPE(None, c_void_p, c_int, c_int)
def SSL_CTX_set_info_callback(ctx, callback):
"""
Set the info callback
:param callback: The Python callback to use
:return: None
"""
def py_info_callback(ssl, where, ret):
try:
callback(SSL(ssl), where, ret)
except:
pass
return
global _info_callback
_info_callback = _rvoid_voidp_int_int(py_info_callback)
_SSL_CTX_set_info_callback(ctx, _info_callback)
def SSL_state_string_long(ssl):
try:
ret = _SSL_state_string_long(ssl)
except:
pass
return ret
# sigalgs_list: (Only for DTLS v1.2)
#
# The short or long name values for digests can be used in a string (for example "MD5", "SHA1", "SHA224", "SHA256", "SHA384", "SHA512")
# and the public key algorithm strings "RSA", "DSA" or "ECDSA".
# The use of MD5 as a digest is strongly discouraged due to security weaknesses.
# Example: "ECDSA+SHA256:RSA+SHA256"
def SSL_CTX_set1_client_sigalgs(ctx, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist)
def SSL_CTX_set1_client_sigalgs_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s)
def SSL_set1_client_sigalgs(ssl, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist)
def SSL_set1_client_sigalgs_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s)
def SSL_CTX_set1_sigalgs(ctx, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS, len(_slist), _slist)
def SSL_CTX_set1_sigalgs_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS_LIST, 0, _s)
def SSL_set1_sigalgs(ssl, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS, len(_slist), _slist)
def SSL_set1_sigalgs_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS_LIST, 0, _s)

View File

@ -34,18 +34,16 @@ has the following effects:
PROTOCOL_DTLSv1 for the parameter ssl_version is supported PROTOCOL_DTLSv1 for the parameter ssl_version is supported
""" """
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_DGRAM from socket import AF_INET, SOCK_DGRAM, getaddrinfo
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE
from types import MethodType, BuiltinMethodType from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION
from sslconnection import DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
def do_patch(): def do_patch():
import ssl as _ssl # import to be avoided if ssl module is never patched import ssl as _ssl # import to be avoided if ssl module is never patched
global _orig_SSLSocket_init, _orig_get_server_certificate global _orig_SSLSocket_init, _orig_get_server_certificate
@ -53,19 +51,8 @@ def do_patch():
ssl = _ssl ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"): if hasattr(ssl, "PROTOCOL_DTLSv1"):
return return
_orig_wrap_socket = ssl.wrap_socket
ssl.wrap_socket = _wrap_socket
SSLSocket_ = ssl.SSLSocket
class SSLSocket(SSLSocket_):
def __getattr__(self, item):
if hasattr(self, "_sslobj") and hasattr(self._sslobj, item):
return getattr(self._sslobj, item)
raise AttributeError
ssl.SSLSocket = SSLSocket
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
@ -75,23 +62,8 @@ def do_patch():
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error() raise_as_ssl_module_error()
PROTOCOL_SSLv3 = 1
def _wrap_socket(sock, keyfile=None, certfile=None, PROTOCOL_SSLv23 = 2
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None,
cb_user_ssl_ctx_config=None, cb_user_ssl_config=None):
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers,
cb_user_ssl_ctx_config=cb_user_ssl_ctx_config,
cb_user_ssl_config=cb_user_ssl_config)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve a server certificate """Retrieve a server certificate
@ -102,10 +74,10 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
If 'ssl_version' is specified, use it in the connection attempt. If 'ssl_version' is specified, use it in the connection attempt.
""" """
if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2): if ssl_version != PROTOCOL_DTLSv1:
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if ca_certs is not None: if (ca_certs is not None):
cert_reqs = ssl.CERT_REQUIRED cert_reqs = ssl.CERT_REQUIRED
else: else:
cert_reqs = ssl.CERT_NONE cert_reqs = ssl.CERT_NONE
@ -122,8 +94,7 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None, suppress_ragged_eofs=True, ciphers=None):
cb_user_ssl_ctx_config=None, cb_user_ssl_config=None):
is_connection = is_datagram = False is_connection = is_datagram = False
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
is_connection = True is_connection = True
@ -167,8 +138,7 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
server_side, cert_reqs, server_side, cert_reqs,
ssl_version, ca_certs, ssl_version, ca_certs,
do_handshake_on_connect, do_handshake_on_connect,
suppress_ragged_eofs, ciphers, suppress_ragged_eofs, ciphers)
cb_user_ssl_ctx_config, cb_user_ssl_config)
else: else:
self._sslobj = sock self._sslobj = sock
@ -180,8 +150,6 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
self.ciphers = ciphers self.ciphers = ciphers
self.do_handshake_on_connect = do_handshake_on_connect self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs self.suppress_ragged_eofs = suppress_ragged_eofs
self.cb_user_ssl_ctx_config = cb_user_ssl_ctx_config
self.cb_user_ssl_config = cb_user_ssl_config
self._makefile_refs = 0 self._makefile_refs = 0
# Perform method substitution and addition (without reference cycle) # Perform method substitution and addition (without reference cycle)
@ -201,8 +169,7 @@ def _SSLSocket_listen(self, ignored):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers, self.suppress_ragged_eofs, self.ciphers)
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
def _SSLSocket_accept(self): def _SSLSocket_accept(self):
if self._connected: if self._connected:
@ -217,8 +184,7 @@ def _SSLSocket_accept(self):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers, self.suppress_ragged_eofs, self.ciphers)
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
return new_ssl_sock, addr return new_ssl_sock, addr
def _SSLSocket_real_connect(self, addr, return_errno): def _SSLSocket_real_connect(self, addr, return_errno):
@ -229,8 +195,7 @@ def _SSLSocket_real_connect(self, addr, return_errno):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers, self.suppress_ragged_eofs, self.ciphers)
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
try: try:
self._sslobj.connect(addr) self._sslobj.connect(addr)
except socket_error as e: except socket_error as e:

Binary file not shown.

Binary file not shown.

View File

@ -48,14 +48,13 @@ from logging import getLogger
from os import urandom from os import urandom
from select import select from select import select
from weakref import proxy from weakref import proxy
from err import openssl_error, InvalidSocketError from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_SYSCALL from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR from err import ERR_BOTH_KEY_CERT_FILES_SVR
from x509 import _X509, decode_cert from x509 import _X509, decode_cert
from tlock import tlock_init from tlock import tlock_init
from openssl import * from openssl import *
@ -64,8 +63,6 @@ from util import _Rsrc, _BIO
_logger = getLogger(__name__) _logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256 PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259
CERT_NONE = 0 CERT_NONE = 0
CERT_OPTIONAL = 1 CERT_OPTIONAL = 1
CERT_REQUIRED = 2 CERT_REQUIRED = 2
@ -125,185 +122,6 @@ class _CallbackProxy(object):
return self.ssl_func(self.ssl_connection, *args, **kwargs) return self.ssl_func(self.ssl_connection, *args, **kwargs)
def _ssl_logging_cb(conn, where, return_code):
def get_alert_desc(return_code):
_alertDesc = {
0: 'close notify',
10: 'unexpected message',
20: 'bad record mac',
30: 'decompression failure',
40: 'handshake failure',
41: 'no certificate',
42: 'bad certificate',
43: 'unsupported certificate',
44: 'certificate revoked',
45: 'certificate expired',
46: 'certificate unknown',
47: 'illegal parameter',
}
_typeDescr = {
1: 'warning',
2: 'fatal',
}
_type = return_code >> 8
_ret = return_code & 0xFF
_typeStr = _typeDescr[_type] if _type in _typeDescr else ('unknown (%d)' % _type)
if _ret in _alertDesc:
return _typeStr, _alertDesc[_ret]
return _typeStr, str(_ret)
_state = where & ~SSL_ST_MASK
state = "SSL_undef (%04x)" % _state
if _state & SSL_ST_INIT == SSL_ST_INIT:
state = "SSL_init"
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state = "SSL_renew"
elif _state & SSL_ST_CONNECT:
state = "SSL_connect"
elif _state & SSL_ST_ACCEPT:
state = "SSL_accept"
elif _state & SSL_ST_BEFORE:
state = "SSL_before"
if where & SSL_CB_LOOP:
state += '_loop'
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
elif where & SSL_CB_ALERT:
op = "read" if where & SSL_CB_READ else "write"
state += '_alert'
_logger.debug("%s %s: %s" % (state, op, ' - '.join(get_alert_desc(return_code))))
elif where & SSL_CB_EXIT:
state += '_exit'
if return_code == 0:
_logger.debug("%s: failed in %s" % (state, SSL_state_string_long(conn)))
elif return_code < 0:
_logger.debug("%s: error %d in %s" % (state, return_code, SSL_state_string_long(conn)))
else:
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
else:
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
class SSLContext(object):
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x._nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u'''
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].as_EC_KEY()
retVal = SSL_CTX_set_tmp_ecdh(self._ctx, key)
EC_KEY_free(key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object): class SSLConnection(object):
"""DTLS peer association """DTLS peer association
@ -331,12 +149,7 @@ class SSLConnection(object):
else: else:
self._rsock = rsock self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
server_method = DTLS_server_method self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method()))
if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
@ -366,12 +179,7 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
client_method = DTLSv1_client_method self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method()))
if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
else: else:
@ -389,9 +197,7 @@ class SSLConnection(object):
# corruption when packet loss occurs # corruption when packet loss occurs
SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION) SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION)
if self._certfile: if self._certfile:
# SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile)
SSL_CTX_use_certificate_file(self._ctx.value, self._certfile,
SSL_FILE_TYPE_PEM)
if self._keyfile: if self._keyfile:
SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile, SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile,
SSL_FILE_TYPE_PEM) SSL_FILE_TYPE_PEM)
@ -402,8 +208,6 @@ class SSLConnection(object):
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err: except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err) raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_ssl_ctx_config:
self._user_ssl_ctx_config(SSLContext(self._ctx.value))
def _copy_server(self): def _copy_server(self):
source = self._sock source = self._sock
@ -437,8 +241,6 @@ class SSLConnection(object):
new_source_wbio.value) new_source_wbio.value)
new_source_rbio.disown() new_source_rbio.disown()
new_source_wbio.disown() new_source_wbio.disown()
if self._user_ssl_config:
self._user_ssl_config(SSL(source._ssl.value))
def _reconnect_unwrapped(self): def _reconnect_unwrapped(self):
source = self._sock source = self._sock
@ -508,11 +310,9 @@ class SSLConnection(object):
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLSv1, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None, suppress_ragged_eofs=True, ciphers=None):
cb_user_ssl_ctx_config=None,
cb_user_ssl_config=None):
"""Constructor """Constructor
Arguments: Arguments:
@ -530,13 +330,10 @@ class SSLConnection(object):
if not ciphers: if not ciphers:
ciphers = "DEFAULT" ciphers = "DEFAULT"
assert ssl_version in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2)
self._sock = sock self._sock = sock
self._keyfile = keyfile self._keyfile = keyfile
self._certfile = certfile self._certfile = certfile
self._cert_reqs = cert_reqs self._cert_reqs = cert_reqs
self._ssl_version = ssl_version
self._ca_certs = ca_certs self._ca_certs = ca_certs
self._do_handshake_on_connect = do_handshake_on_connect self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs self._suppress_ragged_eofs = suppress_ragged_eofs
@ -544,9 +341,6 @@ class SSLConnection(object):
self._handshake_done = False self._handshake_done = False
self._wbio_nb = self._rbio_nb = False self._wbio_nb = self._rbio_nb = False
self._user_ssl_ctx_config = cb_user_ssl_ctx_config
self._user_ssl_config = cb_user_ssl_config
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
post_init = self._copy_server() post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket): elif isinstance(sock, _UnwrappedSocket):
@ -564,8 +358,6 @@ class SSLConnection(object):
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown() self._rbio.disown()
self._wbio.disown() self._wbio.disown()
if self._user_ssl_config:
self._user_ssl_config(SSL(self._ssl.value))
if post_init: if post_init:
post_init() post_init()
@ -670,11 +462,9 @@ class SSLConnection(object):
_logger.debug("Accept returning without connection") _logger.debug("Accept returning without connection")
return return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True, new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version, self._cert_reqs, PROTOCOL_DTLSv1,
self._ca_certs, self._do_handshake_on_connect, self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers, self._suppress_ragged_eofs, self._ciphers)
cb_user_ssl_ctx_config=self._user_ssl_ctx_config,
cb_user_ssl_config=self._user_ssl_config)
new_peer = self._pending_peer_address new_peer = self._pending_peer_address
self._pending_peer_address = None self._pending_peer_address = None
if self._do_handshake_on_connect: if self._do_handshake_on_connect:
@ -750,12 +540,8 @@ class SSLConnection(object):
number of bytes actually transmitted number of bytes actually transmitted
""" """
retVal = self._wrap_socket_library_call( return self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
# for client side ... we want to know when the handshake is completed
if self._wbio is self._rbio and retVal > 0:
self._handshake_done = True
return retVal
def shutdown(self): def shutdown(self):
"""Shut down the DTLS connection """Shut down the DTLS connection

View File

@ -1,286 +0,0 @@
# -*- encoding: utf-8 -*-
import datetime
import select
from logging import getLogger
import ssl
import socket
from dtls import do_patch
do_patch()
_logger = getLogger(__name__)
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
class DtlsSocket(object):
def __init__(self,
host,
port,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None):
self._ssl_logging = False
self._peer = (host, int(port))
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._sock = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
ciphers=self._ciphers,
cb_user_ssl_ctx_config=self.user_ssl_ctx_config,
cb_user_ssl_config=self.user_ssl_config)
if self._server_side:
self._clients = {}
self._timeout = None
self._sock.bind(self._peer)
self._sock.listen(0)
else:
self._sock.connect(self._peer)
def user_ssl_ctx_config(self, _ctx):
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain()
_ctx.set_ecdh_curve() # ("secp256k1")
def user_ssl_config(self, _ssl):
if self._user_mtu:
_ssl.set_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
self._sock.unwrap()
self._sock.close()
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout:
raise e_timeout
else:
for conn in r: # type: ssl.SSLSocket
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients:
return buf, self._clients[conn].getAddr()
for conn in self._getClientReadingSockets():
if conn.get_timeout():
conn.handle_timeout()
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
else:
if buf:
return buf, self._peer
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
while True:
try:
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl:
if str(e_ssl).startswith("503:"):
# The write operation timed out
continue
# elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ]:
# # no ciphers available
# if e_ssl.args[1][0][0] in [336081077, ]:
# raise
raise
else:
if bytes_sent:
break
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept:
pass
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
raise
self._clients[client] = _ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake:
if str(e_handshake).startswith("504:"):
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e_handshake
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
ret = conn.send(data.raw)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write:
raise
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
conn.unwrap()
conn.close()
except Exception as e_drop:
pass