diff --git a/dtls/openssl.py b/dtls/openssl.py index 1e26006..1331a01 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -84,7 +84,6 @@ BIO_NOCLOSE = 0x00 BIO_CLOSE = 0x01 SSLEAY_VERSION = 0 SSL_OP_NO_COMPRESSION = 0x00020000 -SSL_OP_NO_QUERY_MTU = 0x00001000 SSL_VERIFY_NONE = 0x00 SSL_VERIFY_PEER = 0x01 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 @@ -98,58 +97,17 @@ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100 SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200 SSL_SESS_CACHE_NO_INTERNAL = \ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE -SSL_BUILD_CHAIN_FLAG_UNTRUSTED = 0x1 -SSL_BUILD_CHAIN_FLAG_NO_ROOT = 0x2 -SSL_BUILD_CHAIN_FLAG_CHECK = 0x4 -SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = 0x8 -SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = 0x10 SSL_FILE_TYPE_PEM = 1 GEN_DIRNAME = 4 NID_subject_alt_name = 85 CRYPTO_LOCK = 1 -SSL_ST_MASK = 0x0FFF -SSL_ST_CONNECT = 0x1000 -SSL_ST_ACCEPT = 0x2000 -SSL_ST_INIT = (SSL_ST_CONNECT | SSL_ST_ACCEPT) -SSL_ST_BEFORE = 0x4000 -SSL_ST_OK = 0x03 -SSL_ST_RENEGOTIATE = (0x04 | SSL_ST_INIT) -SSL_ST_ERR = 0x05 - -SSL_CB_LOOP = 0x01 -SSL_CB_EXIT = 0x02 -SSL_CB_READ = 0x04 -SSL_CB_WRITE = 0x08 -SSL_CB_ALERT = 0x4000 -SSL_CB_READ_ALERT = (SSL_CB_ALERT | SSL_CB_READ) -SSL_CB_WRITE_ALERT = (SSL_CB_ALERT | SSL_CB_WRITE) -SSL_CB_ACCEPT_LOOP = (SSL_ST_ACCEPT | SSL_CB_LOOP) -SSL_CB_ACCEPT_EXIT = (SSL_ST_ACCEPT | SSL_CB_EXIT) -SSL_CB_CONNECT_LOOP = (SSL_ST_CONNECT | SSL_CB_LOOP) -SSL_CB_CONNECT_EXIT = (SSL_ST_CONNECT | SSL_CB_EXIT) -SSL_CB_HANDSHAKE_START = 0x10 -SSL_CB_HANDSHAKE_DONE = 0x20 - # # Integer constants - internal # SSL_CTRL_SET_SESS_CACHE_MODE = 44 SSL_CTRL_SET_READ_AHEAD = 41 SSL_CTRL_OPTIONS = 32 -SSL_CTRL_CLEAR_OPTIONS = 77 -SSL_CTRL_SET_ECDH_AUTO = 94 -SSL_CTRL_BUILD_CERT_CHAIN = 105 -SSL_CTRL_SET_MTU = 17 -SSL_CTRL_GET_CURVES = 90 -SSL_CTRL_SET_CURVES = 91 -SSL_CTRL_SET_CURVES_LIST = 92 -SSL_CTRL_GET_SHARED_CURVE = 93 -SSL_CTRL_SET_SIGALGS = 97 -SSL_CTRL_SET_SIGALGS_LIST = 98 -SSL_CTRL_SET_CLIENT_SIGALGS = 101 -SSL_CTRL_SET_CLIENT_SIGALGS_LIST = 102 -SSL_CTRL_SET_TMP_ECDH = 4 BIO_CTRL_INFO = 3 BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 @@ -161,53 +119,6 @@ DTLS_CTRL_LISTEN = 75 X509_NAME_MAXLEN = 256 GETS_MAXLEN = 2048 - -class _EllipticCurve(object): - _curves = None - - @classmethod - def _load_elliptic_curves(cls, lib): - num_curves = lib.EC_get_builtin_curves(0, 0) - if num_curves > 0: - builtin_curves = create_string_buffer(sizeof(EC_builtin_curve) * num_curves) - lib.EC_get_builtin_curves(builtin_curves, num_curves) - return set(cls.from_nid(lib, c.nid) for c in cast(builtin_curves, POINTER(EC_builtin_curve))[:num_curves]) - return set() - - @classmethod - def _get_elliptic_curves(cls, lib): - if cls._curves is None: - cls._curves = cls._load_elliptic_curves(lib) - return cls._curves - - @classmethod - def from_nid(cls, lib, nid): - return cls(lib, nid, cast(lib.OBJ_nid2sn(nid), c_char_p).value.decode("ascii")) - - def __init__(self, lib, nid, name): - self._lib = lib - self._nid = nid - self.name = name - - def __repr__(self): - return "" % (self.name,) - - def as_EC_KEY(self): - key = self._lib.EC_KEY_new_by_curve_name(self._nid) - return key - - -def get_elliptic_curves(): - return _EllipticCurve._get_elliptic_curves(libcrypto) - - -def get_elliptic_curve(name): - for curve in get_elliptic_curves(): - if curve.name == name: - return curve - raise ValueError("unknown curve name", name) - - # # Parameter data types # @@ -260,11 +171,6 @@ class SSL(FuncParam): super(SSL, self).__init__(value) -class EC_builtin_curve(Structure): - _fields_ = [("nid", c_int), - ("comment", c_char_p)] - - class BIO(FuncParam): def __init__(self, value): super(BIO, self).__init__(value) @@ -566,22 +472,14 @@ _subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSLEAY_VERSION", - "SSL_OP_NO_COMPRESSION", "SSL_OP_NO_QUERY_MTU", + "SSL_OP_NO_COMPRESSION", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", "SSL_SESS_CACHE_SERVER", "SSL_SESS_CACHE_BOTH", "SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", - "SSL_ST_MASK", "SSL_ST_CONNECT", "SSL_ST_ACCEPT", "SSL_ST_INIT", "SSL_ST_BEFORE", "SSL_ST_OK", - "SSL_ST_RENEGOTIATE", "SSL_ST_ERR", "SSL_CB_LOOP", "SSL_CB_EXIT", "SSL_CB_READ", "SSL_CB_WRITE", - "SSL_CB_ALERT", "SSL_CB_READ_ALERT", "SSL_CB_WRITE_ALERT", - "SSL_CB_ACCEPT_LOOP", "SSL_CB_ACCEPT_EXIT", - "SSL_CB_CONNECT_LOOP", "SSL_CB_CONNECT_EXIT", - "SSL_CB_HANDSHAKE_START", "SSL_CB_HANDSHAKE_DONE", "SSL_FILE_TYPE_PEM", - "SSL_BUILD_CHAIN_FLAG_NO_ROOT", - "SSL_state_string_long", "GEN_DIRNAME", "NID_subject_alt_name", "CRYPTO_LOCK", "CRYPTO_set_locking_callback", @@ -591,28 +489,15 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "BIO_dgram_set_connected", "BIO_dgram_get_peer", "BIO_dgram_set_peer", "BIO_set_nbio", - "SSL_CTX_set_ecdh_auto", - "SSL_CTX_build_cert_chain", "SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead", "SSL_CTX_set_options", - "SSL_set_options", "SSL_clear_options", "SSL_get_options", - "SSL_get1_curves", "SSL_CTX_set1_curves", "SSL_CTX_set1_curves_list", "SSL_set1_curves_list", - "SSL_CTX_set_tmp_ecdh", - "SSL_CTX_set_info_callback", - "SSL_set_mtu", "SSL_read", "SSL_write", "SSL_CTX_set_cookie_cb", - "SSL_set1_client_sigalgs_list", "SSL_set1_client_sigalgs", - "SSL_CTX_set1_client_sigalgs_list", "SSL_CTX_set1_client_sigalgs", - "SSL_set1_sigalgs_list", "SSL_set1_sigalgs", - "SSL_CTX_set1_sigalgs_list", "SSL_CTX_set1_sigalgs", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", "X509_get_notAfter", "ASN1_item_d2i", "GENERAL_NAME_print", - "EC_KEY_free", "sk_value", "sk_pop_free", - "get_elliptic_curves", "i2d_X509"] # note: the following map adds to this list map(lambda x: _make_function(*x), ( @@ -626,9 +511,6 @@ map(lambda x: _make_function(*x), ( ("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), - ("DTLSv1_2_client_method", libssl, ((DTLSv1Method, "ret"),)), - ("DTLSv1_2_server_method", libssl, ((DTLSv1Method, "ret"),)), - ("DTLS_server_method", libssl, ((DTLSv1Method, "ret"),)), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), ("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))), ("SSL_CTX_set_cookie_generate_cb", libssl, @@ -681,8 +563,6 @@ map(lambda x: _make_function(*x), ( ("SSL_CTX_set_verify", libssl, ((None, "ret"), (SSLCTX, "ctx"), (c_int, "mode"), (c_void_p, "verify_callback", 1, None))), - ("SSL_CTX_set_verify_depth", libssl, - ((None, "ret"), (SSLCTX, "ctx"), (c_int, "depth"))), ("SSL_accept", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_connect", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))), @@ -750,22 +630,6 @@ map(lambda x: _make_function(*x), ( ("SSL_CIPHER_get_bits", libssl, ((c_int, "ret"), (SSL_CIPHER, "cipher"), (POINTER(c_int), "alg_bits", 1, None)), True, None), - ("EC_get_builtin_curves", libcrypto, - ((c_int, "ret"), (POINTER(EC_builtin_curve), "r"), (c_int, "nitems"))), - ("EC_KEY_new_by_curve_name", libcrypto, - ((POINTER(c_char), "ret"), (c_int, "nid"))), - ("EC_KEY_free", libcrypto, - ((None, "ret"), (POINTER(c_char), "key")), False), - ("OBJ_nid2sn", libcrypto, - ((c_char_p, "ret"), (c_int, "n"))), - ("EC_curve_nist2nid", libcrypto, - ((c_int, "ret"), (POINTER(c_char), "name")), True, None), - ("EC_curve_nid2nist", libcrypto, - ((c_char_p, "ret"), (c_int, "nid")), True, None), - ("SSL_CTX_set_info_callback", libssl, - ((None, "ret"), (SSLCTX, "ctx"), (c_void_p, "callback")), False), - ("SSL_state_string_long", libssl, - ((c_char_p, "ret"), (SSL, "ssl")), False), )) # @@ -784,15 +648,6 @@ def CRYPTO_set_locking_callback(locking_function): _locking_cb = _rvoid_int_int_charp_int(py_locking_function) _CRYPTO_set_locking_callback(_locking_cb) -def SSL_set_mtu(ssl, mtu): - return _SSL_ctrl(ssl, SSL_CTRL_SET_MTU, mtu, None) - -def SSL_CTX_build_cert_chain(ctx, flags): - return _SSL_CTX_ctrl(ctx, SSL_CTRL_BUILD_CERT_CHAIN, flags, None) - -def SSL_CTX_set_ecdh_auto(ctx, onoff): - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, None) - def SSL_CTX_set_session_cache_mode(ctx, mode): # Returns the previous value of mode _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SESS_CACHE_MODE, mode, None) @@ -803,47 +658,7 @@ def SSL_CTX_set_read_ahead(ctx, m): def SSL_CTX_set_options(ctx, options): # Returns the new option bitmaks after adding the given options - return _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) - -def SSL_set_options(ssl, op): - return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, op, None) - -def SSL_clear_options(ssl, op): - return _SSL_ctrl(ssl, SSL_CTRL_CLEAR_OPTIONS, op, None) - -def SSL_get_options(ssl): - return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, 0, None) - -def SSL_get1_curves(ssl, s): - _s = cast(s, POINTER(c_int)) - if s: - mem = None - cnt = SSL_get1_curves(ssl, 0) - if cnt >= s: - mem = create_string_buffer(sizeof(POINTER(c_int)) * s) - ret = _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, s, mem) - return [x for x in cast(mem, POINTER(c_int))[:s]] - else: - return _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, 0, _s) - -def SSL_CTX_set1_curves(ctx, clist, clistlen): - _curves = (c_int * len(clist))(*clist) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES, len(_curves), _curves) - -def SSL_get_shared_curve(ssl, n): - return _SSL_ctrl(ssl, SSL_CTRL_GET_SHARED_CURVE, n, 0) - -def SSL_CTX_set1_curves_list(ctx, s): - _s = cast(s, POINTER(c_char)) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES_LIST, 0, _s) - -def SSL_set1_curves_list(ssl, s): - _s = cast(s, POINTER(c_char)) - return _SSL_ctrl(ssl, SSL_CTRL_SET_CURVES_LIST, 0, _s) - -def SSL_CTX_set_tmp_ecdh(ctx, ecdh): - # return 1 on success and 0 on failure - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, ecdh) + _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) _rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) @@ -930,8 +745,6 @@ def SSL_write(ssl, data): str_data = data elif hasattr(data, "tobytes") and callable(data.tobytes): str_data = data.tobytes() - elif isinstance(data, ctypes.Array): - str_data = data.raw else: str_data = str(data) return _SSL_write(ssl, str_data, len(str_data)) @@ -1014,73 +827,3 @@ def i2d_X509(x509): bio = _BIO(BIO_new(BIO_s_mem())) _i2d_X509_bio(bio.value, x509) return BIO_get_mem_data(bio.value) - -def EC_KEY_free(key): - _EC_KEY_free(cast(key, POINTER(c_char))) - -_rvoid_voidp_int_int = CFUNCTYPE(None, c_void_p, c_int, c_int) - -def SSL_CTX_set_info_callback(ctx, callback): - """ - Set the info callback - - :param callback: The Python callback to use - :return: None - """ - def py_info_callback(ssl, where, ret): - try: - callback(SSL(ssl), where, ret) - except: - pass - return - - global _info_callback - _info_callback = _rvoid_voidp_int_int(py_info_callback) - _SSL_CTX_set_info_callback(ctx, _info_callback) - - -def SSL_state_string_long(ssl): - try: - ret = _SSL_state_string_long(ssl) - except: - pass - return ret - -# sigalgs_list: (Only for DTLS v1.2) -# -# The short or long name values for digests can be used in a string (for example "MD5", "SHA1", "SHA224", "SHA256", "SHA384", "SHA512") -# and the public key algorithm strings "RSA", "DSA" or "ECDSA". -# The use of MD5 as a digest is strongly discouraged due to security weaknesses. -# Example: "ECDSA+SHA256:RSA+SHA256" - -def SSL_CTX_set1_client_sigalgs(ctx, slist, slistlen): - _slist = (c_int * len(slist))(*slist) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist) - -def SSL_CTX_set1_client_sigalgs_list(ctx, s): - _s = cast(s, POINTER(c_char)) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s) - -def SSL_set1_client_sigalgs(ssl, slist, slistlen): - _slist = (c_int * len(slist))(*slist) - return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist) - -def SSL_set1_client_sigalgs_list(ssl, s): - _s = cast(s, POINTER(c_char)) - return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s) - -def SSL_CTX_set1_sigalgs(ctx, slist, slistlen): - _slist = (c_int * len(slist))(*slist) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS, len(_slist), _slist) - -def SSL_CTX_set1_sigalgs_list(ctx, s): - _s = cast(s, POINTER(c_char)) - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS_LIST, 0, _s) - -def SSL_set1_sigalgs(ssl, slist, slistlen): - _slist = (c_int * len(slist))(*slist) - return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS, len(_slist), _slist) - -def SSL_set1_sigalgs_list(ssl, s): - _s = cast(s, POINTER(c_char)) - return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS_LIST, 0, _s) diff --git a/dtls/patch.py b/dtls/patch.py index f3bfeb1..d7d73d4 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -34,18 +34,16 @@ has the following effects: PROTOCOL_DTLSv1 for the parameter ssl_version is supported """ -from socket import socket, getaddrinfo, _delegate_methods, error as socket_error -from socket import AF_INET, SOCK_DGRAM -from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE -from types import MethodType, BuiltinMethodType +from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error +from socket import AF_INET, SOCK_DGRAM, getaddrinfo +from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE +from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION +from sslconnection import DTLS_OPENSSL_VERSION_INFO +from err import raise_as_ssl_module_error +from types import MethodType from weakref import proxy import errno -from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 -from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO -from err import raise_as_ssl_module_error - - def do_patch(): import ssl as _ssl # import to be avoided if ssl module is never patched global _orig_SSLSocket_init, _orig_get_server_certificate @@ -53,19 +51,8 @@ def do_patch(): ssl = _ssl if hasattr(ssl, "PROTOCOL_DTLSv1"): return - _orig_wrap_socket = ssl.wrap_socket - ssl.wrap_socket = _wrap_socket - SSLSocket_ = ssl.SSLSocket - class SSLSocket(SSLSocket_): - def __getattr__(self, item): - if hasattr(self, "_sslobj") and hasattr(self._sslobj, item): - return getattr(self._sslobj, item) - raise AttributeError - ssl.SSLSocket = SSLSocket ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 - ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2 ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" - ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2" ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO @@ -75,23 +62,8 @@ def do_patch(): ssl.get_server_certificate = _get_server_certificate raise_as_ssl_module_error() - -def _wrap_socket(sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None, - cb_user_ssl_ctx_config=None, cb_user_ssl_config=None): - - return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile, - server_side=server_side, cert_reqs=cert_reqs, - ssl_version=ssl_version, ca_certs=ca_certs, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - ciphers=ciphers, - cb_user_ssl_ctx_config=cb_user_ssl_ctx_config, - cb_user_ssl_config=cb_user_ssl_config) - +PROTOCOL_SSLv3 = 1 +PROTOCOL_SSLv23 = 2 def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): """Retrieve a server certificate @@ -102,10 +74,10 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): If 'ssl_version' is specified, use it in the connection attempt. """ - if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2): + if ssl_version != PROTOCOL_DTLSv1: return _orig_get_server_certificate(addr, ssl_version, ca_certs) - if ca_certs is not None: + if (ca_certs is not None): cert_reqs = ssl.CERT_REQUIRED else: cert_reqs = ssl.CERT_NONE @@ -122,8 +94,7 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None, - cb_user_ssl_ctx_config=None, cb_user_ssl_config=None): + suppress_ragged_eofs=True, ciphers=None): is_connection = is_datagram = False if isinstance(sock, SSLConnection): is_connection = True @@ -167,8 +138,7 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, server_side, cert_reqs, ssl_version, ca_certs, do_handshake_on_connect, - suppress_ragged_eofs, ciphers, - cb_user_ssl_ctx_config, cb_user_ssl_config) + suppress_ragged_eofs, ciphers) else: self._sslobj = sock @@ -180,8 +150,6 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, self.ciphers = ciphers self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs - self.cb_user_ssl_ctx_config = cb_user_ssl_ctx_config - self.cb_user_ssl_config = cb_user_ssl_config self._makefile_refs = 0 # Perform method substitution and addition (without reference cycle) @@ -201,8 +169,7 @@ def _SSLSocket_listen(self, ignored): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers, - self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) + self.suppress_ragged_eofs, self.ciphers) def _SSLSocket_accept(self): if self._connected: @@ -217,8 +184,7 @@ def _SSLSocket_accept(self): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers, - self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) + self.suppress_ragged_eofs, self.ciphers) return new_ssl_sock, addr def _SSLSocket_real_connect(self, addr, return_errno): @@ -229,8 +195,7 @@ def _SSLSocket_real_connect(self, addr, return_errno): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers, - self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) + self.suppress_ragged_eofs, self.ciphers) try: self._sslobj.connect(addr) except socket_error as e: diff --git a/dtls/prebuilt/win32-x86/libeay32.dll b/dtls/prebuilt/win32-x86/libeay32.dll index db98e5b..8f8a254 100644 Binary files a/dtls/prebuilt/win32-x86/libeay32.dll and b/dtls/prebuilt/win32-x86/libeay32.dll differ diff --git a/dtls/prebuilt/win32-x86/ssleay32.dll b/dtls/prebuilt/win32-x86/ssleay32.dll index 965b609..1a70063 100644 Binary files a/dtls/prebuilt/win32-x86/ssleay32.dll and b/dtls/prebuilt/win32-x86/ssleay32.dll differ diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 20447f4..9a5dcef 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -48,14 +48,13 @@ from logging import getLogger from os import urandom from select import select from weakref import proxy - from err import openssl_error, InvalidSocketError from err import raise_ssl_error -from err import SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_SYSCALL +from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT -from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR +from err import ERR_BOTH_KEY_CERT_FILES_SVR from x509 import _X509, decode_cert from tlock import tlock_init from openssl import * @@ -64,8 +63,6 @@ from util import _Rsrc, _BIO _logger = getLogger(__name__) PROTOCOL_DTLSv1 = 256 -PROTOCOL_DTLSv1_2 = 258 -PROTOCOL_DTLS = 259 CERT_NONE = 0 CERT_OPTIONAL = 1 CERT_REQUIRED = 2 @@ -125,185 +122,6 @@ class _CallbackProxy(object): return self.ssl_func(self.ssl_connection, *args, **kwargs) -def _ssl_logging_cb(conn, where, return_code): - - def get_alert_desc(return_code): - _alertDesc = { - 0: 'close notify', - 10: 'unexpected message', - 20: 'bad record mac', - 30: 'decompression failure', - 40: 'handshake failure', - 41: 'no certificate', - 42: 'bad certificate', - 43: 'unsupported certificate', - 44: 'certificate revoked', - 45: 'certificate expired', - 46: 'certificate unknown', - 47: 'illegal parameter', - } - _typeDescr = { - 1: 'warning', - 2: 'fatal', - } - - _type = return_code >> 8 - _ret = return_code & 0xFF - - _typeStr = _typeDescr[_type] if _type in _typeDescr else ('unknown (%d)' % _type) - - if _ret in _alertDesc: - return _typeStr, _alertDesc[_ret] - - return _typeStr, str(_ret) - - _state = where & ~SSL_ST_MASK - state = "SSL_undef (%04x)" % _state - - if _state & SSL_ST_INIT == SSL_ST_INIT: - state = "SSL_init" - if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE: - state = "SSL_renew" - elif _state & SSL_ST_CONNECT: - state = "SSL_connect" - elif _state & SSL_ST_ACCEPT: - state = "SSL_accept" - elif _state & SSL_ST_BEFORE: - state = "SSL_before" - - if where & SSL_CB_LOOP: - state += '_loop' - _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) - - elif where & SSL_CB_ALERT: - op = "read" if where & SSL_CB_READ else "write" - state += '_alert' - _logger.debug("%s %s: %s" % (state, op, ' - '.join(get_alert_desc(return_code)))) - - elif where & SSL_CB_EXIT: - state += '_exit' - if return_code == 0: - _logger.debug("%s: failed in %s" % (state, SSL_state_string_long(conn))) - elif return_code < 0: - _logger.debug("%s: error %d in %s" % (state, return_code, SSL_state_string_long(conn))) - else: - _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) - - else: - _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) - - -class SSLContext(object): - - def __init__(self, ctx): - self._ctx = ctx - - def set_ciphers(self, ciphers): - u''' - s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html - - :param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ... - :return: 1 for success and 0 for failure - ''' - retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers) - return retVal - - def set_sigalgs(self, sigalgs): - u''' - s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html - - :param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ... - :return: 1 for success and 0 for failure - ''' - retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs) - return retVal - - def set_curves(self, curves): - u''' Set supported curves by name, nid or nist. - - :param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ... - :return: 1 for success and 0 for failure - ''' - retVal = None - if isinstance(curves, str): - retVal = SSL_CTX_set1_curves_list(self._ctx, curves) - elif isinstance(curves, tuple): - retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves)) - return retVal - - @staticmethod - def get_ec_nist2nid(nist): - if not isinstance(nist, tuple): - nist = nist.split(":") - nid = tuple(EC_curve_nist2nid(x) for x in nist) - return nid - - @staticmethod - def get_ec_nid2nist(nid): - if not isinstance(nid, tuple): - nid = (nid, ) - nist = ":".join([EC_curve_nid2nist(x) for x in nid]) - return nist - - @staticmethod - def get_ec_available(bAsName=True): - curves = get_elliptic_curves() - return sorted([x.name for x in curves] if bAsName else [x._nid for x in curves]) - - def set_ecdh_curve(self, curve_name=None): - u''' - Used for server only! - - s.a. openssl.exe ecparam -list_curves - - :param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ... - :return: 1 for success and 0 for failure - ''' - if curve_name: - retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0) - avail_curves = get_elliptic_curves() - key = [curve for curve in avail_curves if curve.name == curve_name][0].as_EC_KEY() - retVal = SSL_CTX_set_tmp_ecdh(self._ctx, key) - EC_KEY_free(key) - else: - retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1) - return retVal - - def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT): - u''' - Used for server side only! - - :param flags: - :return: 1 for success and 0 for failure - ''' - retVal = SSL_CTX_build_cert_chain(self._ctx, flags) - return retVal - - def set_ssl_logging(self, enable=False, func=_ssl_logging_cb): - u''' Enable or disable SSL logging - - :param True | False enable: Enable or disable SSL logging - :param func: Callback function for logging - ''' - if enable: - SSL_CTX_set_info_callback(self._ctx, func) - else: - SSL_CTX_set_info_callback(self._ctx, 0) - - -class SSL(object): - - def __init__(self, ssl): - self._ssl = ssl - - def set_mtu(self, mtu=None): - if mtu: - SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU) - SSL_set_mtu(self._ssl, mtu) - else: - SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU) - - class SSLConnection(object): """DTLS peer association @@ -331,12 +149,7 @@ class SSLConnection(object): else: self._rsock = rsock self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) - server_method = DTLS_server_method - if self._ssl_version == PROTOCOL_DTLSv1_2: - server_method = DTLSv1_2_server_method - elif self._ssl_version == PROTOCOL_DTLSv1: - server_method = DTLSv1_server_method - self._ctx = _CTX(SSL_CTX_new(server_method())) + self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE @@ -366,12 +179,7 @@ class SSLConnection(object): self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._rbio = self._wbio - client_method = DTLSv1_client_method - if self._ssl_version == PROTOCOL_DTLSv1_2: - client_method = DTLSv1_2_client_method - elif self._ssl_version == PROTOCOL_DTLSv1: - client_method = DTLSv1_client_method - self._ctx = _CTX(SSL_CTX_new(client_method())) + self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE else: @@ -389,9 +197,7 @@ class SSLConnection(object): # corruption when packet loss occurs SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION) if self._certfile: - # SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) - SSL_CTX_use_certificate_file(self._ctx.value, self._certfile, - SSL_FILE_TYPE_PEM) + SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) if self._keyfile: SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile, SSL_FILE_TYPE_PEM) @@ -402,8 +208,6 @@ class SSLConnection(object): SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) except openssl_error() as err: raise_ssl_error(ERR_NO_CIPHER, err) - if self._user_ssl_ctx_config: - self._user_ssl_ctx_config(SSLContext(self._ctx.value)) def _copy_server(self): source = self._sock @@ -437,8 +241,6 @@ class SSLConnection(object): new_source_wbio.value) new_source_rbio.disown() new_source_wbio.disown() - if self._user_ssl_config: - self._user_ssl_config(SSL(source._ssl.value)) def _reconnect_unwrapped(self): source = self._sock @@ -508,11 +310,9 @@ class SSLConnection(object): def __init__(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_DTLS, ca_certs=None, + ssl_version=PROTOCOL_DTLSv1, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None, - cb_user_ssl_ctx_config=None, - cb_user_ssl_config=None): + suppress_ragged_eofs=True, ciphers=None): """Constructor Arguments: @@ -530,13 +330,10 @@ class SSLConnection(object): if not ciphers: ciphers = "DEFAULT" - assert ssl_version in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2) - self._sock = sock self._keyfile = keyfile self._certfile = certfile self._cert_reqs = cert_reqs - self._ssl_version = ssl_version self._ca_certs = ca_certs self._do_handshake_on_connect = do_handshake_on_connect self._suppress_ragged_eofs = suppress_ragged_eofs @@ -544,9 +341,6 @@ class SSLConnection(object): self._handshake_done = False self._wbio_nb = self._rbio_nb = False - self._user_ssl_ctx_config = cb_user_ssl_ctx_config - self._user_ssl_config = cb_user_ssl_config - if isinstance(sock, SSLConnection): post_init = self._copy_server() elif isinstance(sock, _UnwrappedSocket): @@ -564,8 +358,6 @@ class SSLConnection(object): SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) self._rbio.disown() self._wbio.disown() - if self._user_ssl_config: - self._user_ssl_config(SSL(self._ssl.value)) if post_init: post_init() @@ -670,11 +462,9 @@ class SSLConnection(object): _logger.debug("Accept returning without connection") return new_conn = SSLConnection(self, self._keyfile, self._certfile, True, - self._cert_reqs, self._ssl_version, + self._cert_reqs, PROTOCOL_DTLSv1, self._ca_certs, self._do_handshake_on_connect, - self._suppress_ragged_eofs, self._ciphers, - cb_user_ssl_ctx_config=self._user_ssl_ctx_config, - cb_user_ssl_config=self._user_ssl_config) + self._suppress_ragged_eofs, self._ciphers) new_peer = self._pending_peer_address self._pending_peer_address = None if self._do_handshake_on_connect: @@ -750,12 +540,8 @@ class SSLConnection(object): number of bytes actually transmitted """ - retVal = self._wrap_socket_library_call( + return self._wrap_socket_library_call( lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) - # for client side ... we want to know when the handshake is completed - if self._wbio is self._rbio and retVal > 0: - self._handshake_done = True - return retVal def shutdown(self): """Shut down the DTLS connection diff --git a/dtls/wrapper.py b/dtls/wrapper.py deleted file mode 100644 index 16fa805..0000000 --- a/dtls/wrapper.py +++ /dev/null @@ -1,286 +0,0 @@ -# -*- encoding: utf-8 -*- - -import datetime -import select - -from logging import getLogger - -import ssl -import socket -from dtls import do_patch -do_patch() - -_logger = getLogger(__name__) - - -class _ClientSession(object): - - def __init__(self, host, port, handshake_done=False): - self.host = host - self.port = int(port) - self.handshake_done = handshake_done - - def getAddr(self): - return self.host, self.port - - -class DtlsSocket(object): - - def __init__(self, - host, - port, - keyfile=None, - certfile=None, - server_side=False, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_DTLSv1_2, - ca_certs=None, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - ciphers=None, - curves=None, - sigalgs=None, - user_mtu=None): - - self._ssl_logging = False - self._peer = (host, int(port)) - self._server_side = server_side - self._ciphers = ciphers - self._curves = curves - self._sigalgs = sigalgs - self._user_mtu = user_mtu - - self._sock = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), - keyfile=keyfile, - certfile=certfile, - server_side=self._server_side, - cert_reqs=cert_reqs, - ssl_version=ssl_version, - ca_certs=ca_certs, - do_handshake_on_connect=do_handshake_on_connect, - ciphers=self._ciphers, - cb_user_ssl_ctx_config=self.user_ssl_ctx_config, - cb_user_ssl_config=self.user_ssl_config) - - if self._server_side: - self._clients = {} - self._timeout = None - - self._sock.bind(self._peer) - self._sock.listen(0) - else: - self._sock.connect(self._peer) - - def user_ssl_ctx_config(self, _ctx): - _ctx.set_ssl_logging(self._ssl_logging) - if self._ciphers: - _ctx.set_ciphers(self._ciphers) - if self._curves: - _ctx.set_curves(self._curves) - if self._sigalgs: - _ctx.set_sigalgs(self._sigalgs) - if self._server_side: - _ctx.build_cert_chain() - _ctx.set_ecdh_curve() # ("secp256k1") - - def user_ssl_config(self, _ssl): - if self._user_mtu: - _ssl.set_mtu(self._user_mtu) - - def settimeout(self, t): - if self._server_side: - self._timeout = t - else: - self._sock.settimeout(t) - - def close(self): - if self._server_side: - for cli in self._clients.keys(): - cli.close() - else: - self._sock.unwrap() - self._sock.close() - - def recvfrom(self, bufsize, flags=0): - if self._server_side: - return self._recvfrom_on_server_side(bufsize, flags=flags) - else: - return self._recvfrom_on_client_side(bufsize, flags=flags) - - def _recvfrom_on_server_side(self, bufsize, flags): - try: - r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) - - except socket.timeout as e_timeout: - raise e_timeout - - else: - for conn in r: # type: ssl.SSLSocket - if self._sockIsServerSock(conn): - # Connect - self._clientAccept(conn) - else: - # Handshake - if not self._clientHandshakeDone(conn): - self._clientDoHandshake(conn) - # Normal read - else: - buf = self._clientRead(conn, bufsize) - if buf and conn in self._clients: - return buf, self._clients[conn].getAddr() - - for conn in self._getClientReadingSockets(): - if conn.get_timeout(): - conn.handle_timeout() - - raise socket.timeout - - def _recvfrom_on_client_side(self, bufsize, flags): - try: - buf = self._sock.recv(bufsize, flags) - - except ssl.SSLError as e_ssl: - if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN: - return '', self._peer - elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: - raise - else: # like in [ssl.SSL_ERROR_WANT_READ, ...] - pass - - else: - if buf: - return buf, self._peer - - raise socket.timeout - - def sendto(self, buf, address): - if self._server_side: - return self._sendto_from_server_side(buf, address) - else: - return self._sendto_from_client_side(buf, address) - - def _sendto_from_server_side(self, buf, address): - for conn, client in self._clients.iteritems(): - if client.getAddr() == address: - return self._clientWrite(conn, buf) - return 0 - - def _sendto_from_client_side(self, buf, address): - while True: - try: - bytes_sent = self._sock.send(buf) - - except ssl.SSLError as e_ssl: - if str(e_ssl).startswith("503:"): - # The write operation timed out - continue - # elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ]: - # # no ciphers available - # if e_ssl.args[1][0][0] in [336081077, ]: - # raise - raise - - else: - if bytes_sent: - break - - return bytes_sent - - def _getClientReadingSockets(self): - return [x for x in self._clients.keys()] - - def _getAllReadingSockets(self): - return [self._sock] + self._getClientReadingSockets() - - def _sockIsServerSock(self, conn): - return conn is self._sock - - def _clientHandshakeDone(self, conn): - return conn in self._clients and self._clients[conn].handshake_done is True - - def _clientAccept(self, conn): - _logger.debug('+' * 60) - ret = None - - try: - ret = conn.accept() - _logger.debug('Accept returned with ... %s' % (str(ret))) - - except Exception as e_accept: - pass - - else: - if ret: - client, addr = ret - host, port = addr - if client in self._clients: - raise - self._clients[client] = _ClientSession(host=host, port=port) - - self._clientDoHandshake(client) - - def _clientDoHandshake(self, conn): - _logger.debug('-' * 60) - conn.setblocking(False) - - try: - conn.do_handshake() - _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr()))) - - self._clients[conn].handshake_done = True - - except ssl.SSLError as e_handshake: - if str(e_handshake).startswith("504:"): - pass - elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ: - pass - else: - raise e_handshake - - def _clientRead(self, conn, bufsize=4096): - _logger.debug('*' * 60) - ret = None - - try: - ret = conn.recv(bufsize) - _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) - - except ssl.SSLError as e_read: - if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN: - self._clientDrop(conn) - elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: - self._clientDrop(conn, error=e_read) - else: # like in [ssl.SSL_ERROR_WANT_READ, ...] - pass - - return ret - - def _clientWrite(self, conn, data): - _logger.debug('#' * 60) - ret = None - - try: - ret = conn.send(data.raw) - _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) - - except Exception as e_write: - raise - - return ret - - def _clientDrop(self, conn, error=None): - _logger.debug('$' * 60) - - try: - if error: - _logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error)) - else: - _logger.debug('Drop client %s' % str(self._clients[conn].getAddr())) - - if conn in self._clients: - del self._clients[conn] - conn.unwrap() - conn.close() - - except Exception as e_drop: - pass